From b3e859c2c6ae440f283417751f704ebb4e39b76b Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Fri, 30 May 2025 10:29:31 -0400 Subject: [PATCH 01/21] Add ExpandDims and Squeeze operations with tests and rewrite rules --- pytensor/xtensor/rewriting/shape.py | 66 +++++++++++++++- pytensor/xtensor/shape.py | 116 ++++++++++++++++++++++++++++ tests/xtensor/test_shape.py | 57 +++++++++++++- 3 files changed, 237 insertions(+), 2 deletions(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index 84447670c2..2b3dd31ec4 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -2,7 +2,7 @@ from pytensor.tensor import broadcast_to, join, moveaxis, specify_shape from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.rewriting.basic import register_xcanonicalize -from pytensor.xtensor.shape import Concat, Stack, Transpose, UnStack +from pytensor.xtensor.shape import Concat, ExpandDims, Squeeze, Stack, Transpose, UnStack @register_xcanonicalize @@ -105,3 +105,67 @@ def lower_transpose(fgraph, node): x_tensor_transposed = x_tensor.transpose(perm) new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims) return [new_out] + + +@register_xcanonicalize +@node_rewriter([ExpandDims]) +def local_expand_dims_reshape(fgraph, node): + """Rewrite rule to convert expand_dims to reshape.""" + if not isinstance(node.op, ExpandDims): + return False + + x = node.inputs[0] + dim = node.op.dim + + # Create new dimensions list with the new dimension + new_dims = list(x.type.dims) + new_dims.append(dim) + + # Create new shape with the new dimension + new_shape = list(x.type.shape) + new_shape.append(1) + + # Create a new reshape operation + from pytensor.xtensor.shape import reshape + + return [reshape(x, new_shape, new_dims)] + + +@register_xcanonicalize +@node_rewriter([Squeeze]) +def local_squeeze_reshape(fgraph, node): + """Rewrite rule to convert squeeze to reshape.""" + if not isinstance(node.op, Squeeze): + return False + + x = node.inputs[0] + dim = node.op.dim + + # Get the index of the dimension to remove + if dim is not None: + if dim not in x.type.dims: + return False + dim_idx = x.type.dims.index(dim) + if x.type.shape[dim_idx] != 1: + return False + else: + # Find all dimensions of size 1 + dim_idx = [i for i, s in enumerate(x.type.shape) if s == 1] + if not dim_idx: + return False + + # Create new dimensions and shape lists + new_dims = list(x.type.dims) + new_shape = list(x.type.shape) + if dim is not None: + new_dims.pop(dim_idx) + new_shape.pop(dim_idx) + else: + # Remove all dimensions of size 1 + new_dims = [d for i, d in enumerate(new_dims) if i not in dim_idx] + new_shape = [s for i, s in enumerate(new_shape) if i not in dim_idx] + + # Create a new reshape operation + from pytensor.xtensor.shape import reshape + + return [reshape(x, new_shape, new_dims)] diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 38b702db84..951f56690f 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -301,3 +301,119 @@ def make_node(self, *inputs: Variable) -> Apply: def concat(xtensors, dim: str): return Concat(dim=dim)(*xtensors) + + +class ExpandDims(XOp): + """Add a new dimension to an XTensorVariable. + + Parameters + ---------- + dim : str or None + The name of the new dimension. If None, the dimension will be unnamed. + """ + + def __init__(self, dim): + self.dim = dim + + def make_node(self, x): + x = as_xtensor(x) + + # Check if dimension already exists + if self.dim is not None and self.dim in x.type.dims: + raise ValueError(f"Dimension {self.dim} already exists") + + # Create new dimensions list with the new dimension + new_dims = list(x.type.dims) + new_dims.append(self.dim) + + # Create new shape with the new dimension + new_shape = list(x.type.shape) + new_shape.append(1) + + output = xtensor( + dtype=x.type.dtype, shape=tuple(new_shape), dims=tuple(new_dims) + ) + return Apply(self, [x], [output]) + + +def expand_dims(x, dim: str): + """Add a new dimension to an XTensorVariable. + + Parameters + ---------- + x : XTensorVariable + The input tensor + dim : str + The name of the new dimension + + Returns + ------- + XTensorVariable + A new tensor with the expanded dimension + """ + return ExpandDims(dim=dim)(x) + + +class Squeeze(XOp): + """Remove a dimension of size 1 from an XTensorVariable. + + Parameters + ---------- + dim : str or None + The name of the dimension to remove. If None, all dimensions of size 1 will be removed. + """ + + def __init__(self, dim=None): + self.dim = dim + + def make_node(self, x): + x = as_xtensor(x) + + # Get the index of the dimension to remove + if self.dim is not None: + if self.dim not in x.type.dims: + raise ValueError(f"Dimension {self.dim} not found") + dim_idx = x.type.dims.index(self.dim) + if x.type.shape[dim_idx] != 1: + raise ValueError( + f"Dimension {self.dim} has size {x.type.shape[dim_idx]}, not 1" + ) + else: + # Find all dimensions of size 1 + dim_idx = [i for i, s in enumerate(x.type.shape) if s == 1] + if not dim_idx: + raise ValueError("No dimensions of size 1 to remove") + + # Create new dimensions and shape lists + new_dims = list(x.type.dims) + new_shape = list(x.type.shape) + if self.dim is not None: + new_dims.pop(dim_idx) + new_shape.pop(dim_idx) + else: + # Remove all dimensions of size 1 + new_dims = [d for i, d in enumerate(new_dims) if i not in dim_idx] + new_shape = [s for i, s in enumerate(new_shape) if i not in dim_idx] + + output = xtensor( + dtype=x.type.dtype, shape=tuple(new_shape), dims=tuple(new_dims) + ) + return Apply(self, [x], [output]) + + +def squeeze(x, dim=None): + """Remove a dimension of size 1 from an XTensorVariable. + + Parameters + ---------- + x : XTensorVariable + The input tensor + dim : str or None, optional + The name of the dimension to remove. If None, all dimensions of size 1 will be removed. + + Returns + ------- + XTensorVariable + A new tensor with the specified dimension removed + """ + return Squeeze(dim=dim)(x) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index fb185ae1ce..1e4705c30f 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -12,7 +12,7 @@ from xarray import DataArray from xarray import concat as xr_concat -from pytensor.xtensor.shape import concat, stack, transpose, unstack +from pytensor.xtensor.shape import concat, expand_dims, squeeze, stack, transpose, unstack from pytensor.xtensor.type import xtensor from tests.xtensor.util import ( xr_arange_like, @@ -256,3 +256,58 @@ def test_concat_scalar(): res = fn(x1_test, x2_test) expected_res = xr_concat([x1_test, x2_test], dim="new_dim") xr_assert_allclose(res, expected_res) + + +def test_expand_dims(): + # Test 1D tensor expansion + x = xtensor("x", dims=("city",), shape=(3,)) + y = expand_dims(x, "country") + assert y.type.dims == ("city", "country") + assert y.type.shape == (3, 1) + + # Test 2D tensor expansion + x2d = xtensor("x2d", dims=("row", "col"), shape=(2, 3)) + y2d = expand_dims(x2d, "batch") + assert y2d.type.dims == ("row", "col", "batch") + assert y2d.type.shape == (2, 3, 1) + + # Test expansion with different dimension name + z = expand_dims(x, "time") + assert z.type.dims == ("city", "time") + assert z.type.shape == (3, 1) + + # Test that expanding with an existing dimension raises an error + with pytest.raises(ValueError): + expand_dims(y, "city") + + # Test that expanding with None dimension works + z = expand_dims(x, None) + assert z.type.dims == ("city", None) + assert z.type.shape == (3, 1) + + +def test_squeeze(): + # Test squeezing a specific dimension + x = xtensor("x", dims=("city", "country"), shape=(3, 1)) + y = squeeze(x, "country") + assert y.type.dims == ("city",) + assert y.type.shape == (3,) + + # Test squeezing all dimensions of size 1 + x2d = xtensor("x2d", dims=("row", "col", "batch"), shape=(2, 1, 1)) + y2d = squeeze(x2d) + assert y2d.type.dims == ("row",) + assert y2d.type.shape == (2,) + + # Test that squeezing a non-existent dimension raises an error + with pytest.raises(ValueError): + squeeze(x, "time") + + # Test that squeezing a dimension of size > 1 raises an error + with pytest.raises(ValueError): + squeeze(x, "city") + + # Test that squeezing when no dimensions are of size 1 raises an error + x3d = xtensor("x3d", dims=("row", "col", "batch"), shape=(2, 3, 4)) + with pytest.raises(ValueError): + squeeze(x3d) From d82487065f0dbaf62591ad8daa8717edee08b4d0 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Fri, 30 May 2025 11:06:40 -0400 Subject: [PATCH 02/21] fix: match xarray's expand_dims behavior with None dimension - Update ExpandDims op and rewrite rule to not add a new dimension when dim is None - Update tests to verify behavior matches xarray --- pytensor/xtensor/rewriting/shape.py | 39 ++++++++------- pytensor/xtensor/shape.py | 24 +++++----- tests/xtensor/test_shape.py | 73 ++++++++++++++++++++++++----- 3 files changed, 96 insertions(+), 40 deletions(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index 2b3dd31ec4..9279635110 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -1,5 +1,5 @@ from pytensor.graph import node_rewriter -from pytensor.tensor import broadcast_to, join, moveaxis, specify_shape +from pytensor.tensor import broadcast_to, join, moveaxis, specify_shape, squeeze, expand_dims from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.rewriting.basic import register_xcanonicalize from pytensor.xtensor.shape import Concat, ExpandDims, Squeeze, Stack, Transpose, UnStack @@ -110,31 +110,37 @@ def lower_transpose(fgraph, node): @register_xcanonicalize @node_rewriter([ExpandDims]) def local_expand_dims_reshape(fgraph, node): - """Rewrite rule to convert expand_dims to reshape.""" + """Rewrite rule to convert expand_dims to pytensor.tensor.expand_dims and broadcast_to if needed.""" if not isinstance(node.op, ExpandDims): return False x = node.inputs[0] dim = node.op.dim + size = getattr(node.op, 'size', 1) - # Create new dimensions list with the new dimension - new_dims = list(x.type.dims) - new_dims.append(dim) + # If dim is None, don't add a new dimension (matching xarray behavior) + if dim is None: + return [x] - # Create new shape with the new dimension - new_shape = list(x.type.shape) - new_shape.append(1) + # Create new dimensions list with the new dimension at the beginning + new_dims = [dim] + list(x.type.dims) - # Create a new reshape operation - from pytensor.xtensor.shape import reshape + # Create new shape with the new dimension at the beginning + new_shape = [1] + list(x.type.shape) - return [reshape(x, new_shape, new_dims)] + # Convert to tensor and use pytensor.tensor.expand_dims + x_tensor = tensor_from_xtensor(x) + x_tensor_expanded = expand_dims(x_tensor, axis=0) + if size != 1: + x_tensor_expanded = broadcast_to(x_tensor_expanded, new_shape) + new_out = xtensor_from_tensor(x_tensor_expanded, dims=tuple(new_dims)) + return [new_out] @register_xcanonicalize @node_rewriter([Squeeze]) def local_squeeze_reshape(fgraph, node): - """Rewrite rule to convert squeeze to reshape.""" + """Rewrite rule to convert squeeze to pytensor.tensor.squeeze.""" if not isinstance(node.op, Squeeze): return False @@ -165,7 +171,8 @@ def local_squeeze_reshape(fgraph, node): new_dims = [d for i, d in enumerate(new_dims) if i not in dim_idx] new_shape = [s for i, s in enumerate(new_shape) if i not in dim_idx] - # Create a new reshape operation - from pytensor.xtensor.shape import reshape - - return [reshape(x, new_shape, new_dims)] + # Convert to tensor and use pytensor.tensor.squeeze + x_tensor = tensor_from_xtensor(x) + x_tensor_squeezed = squeeze(x_tensor, axis=dim_idx if dim is None else dim_idx) + new_out = xtensor_from_tensor(x_tensor_squeezed, dims=tuple(new_dims)) + return [new_out] diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 951f56690f..c8babaf567 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -309,26 +309,28 @@ class ExpandDims(XOp): Parameters ---------- dim : str or None - The name of the new dimension. If None, the dimension will be unnamed. + The name of the new dimension. If None, no new dimension is added. + size : int or symbolic, optional + The size of the new dimension (default 1). """ - - def __init__(self, dim): + def __init__(self, dim, size=1): self.dim = dim + self.size = size def make_node(self, x): x = as_xtensor(x) + # If dim is None, don't add a new dimension (matching xarray behavior) + if self.dim is None: + return Apply(self, [x], [x]) + # Check if dimension already exists - if self.dim is not None and self.dim in x.type.dims: + if self.dim in x.type.dims: raise ValueError(f"Dimension {self.dim} already exists") - # Create new dimensions list with the new dimension - new_dims = list(x.type.dims) - new_dims.append(self.dim) - - # Create new shape with the new dimension - new_shape = list(x.type.shape) - new_shape.append(1) + # Add new dimension at the beginning + new_dims = [self.dim] + list(x.type.dims) + new_shape = [self.size] + list(x.type.shape) output = xtensor( dtype=x.type.dtype, shape=tuple(new_shape), dims=tuple(new_dims) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 1e4705c30f..04a3eaeeab 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -259,31 +259,64 @@ def test_concat_scalar(): def test_expand_dims(): - # Test 1D tensor expansion + import xarray as xr + # 1D case + x_xr = xr.DataArray([0, 1, 2], dims=["city"]) + y_xr = x_xr.expand_dims("country") x = xtensor("x", dims=("city",), shape=(3,)) y = expand_dims(x, "country") - assert y.type.dims == ("city", "country") - assert y.type.shape == (3, 1) + assert y.type.dims == y_xr.dims + assert y.type.shape == y_xr.shape + fn = xr_function([x], y) + x_test = xr_arange_like(x) + res = fn(x_test) + expected_res = x_test.expand_dims("country") + xr_assert_allclose(res, expected_res) - # Test 2D tensor expansion + # 2D case + x2d_xr = xr.DataArray([[0, 1, 2], [3, 4, 5]], dims=["row", "col"]) + y2d_xr = x2d_xr.expand_dims("batch") x2d = xtensor("x2d", dims=("row", "col"), shape=(2, 3)) y2d = expand_dims(x2d, "batch") - assert y2d.type.dims == ("row", "col", "batch") - assert y2d.type.shape == (2, 3, 1) + assert y2d.type.dims == y2d_xr.dims + assert y2d.type.shape == y2d_xr.shape + fn = xr_function([x2d], y2d) + x2d_test = xr_arange_like(x2d) + res = fn(x2d_test) + expected_res = x2d_test.expand_dims("batch") + xr_assert_allclose(res, expected_res) - # Test expansion with different dimension name + # Expansion with different dimension name + z_xr = x_xr.expand_dims("time") z = expand_dims(x, "time") - assert z.type.dims == ("city", "time") - assert z.type.shape == (3, 1) + assert z.type.dims == z_xr.dims + assert z.type.shape == z_xr.shape + fn = xr_function([x], z) + res = fn(x_test) + expected_res = x_test.expand_dims("time") + xr_assert_allclose(res, expected_res) - # Test that expanding with an existing dimension raises an error + # Expanding with an existing dimension raises an error with pytest.raises(ValueError): expand_dims(y, "city") - # Test that expanding with None dimension works + # Expanding with None dimension + print("\nTesting expand_dims with None:") + print("Input xarray dims:", x_xr.dims) + print("Input xarray shape:", x_xr.shape) + z_xr = x_xr.expand_dims(None) + print("Output xarray dims:", z_xr.dims) + print("Output xarray shape:", z_xr.shape) + print("Output xarray data:\n", z_xr.data) z = expand_dims(x, None) - assert z.type.dims == ("city", None) - assert z.type.shape == (3, 1) + print("Our output dims:", z.type.dims) + print("Our output shape:", z.type.shape) + assert z.type.dims == z_xr.dims + assert z.type.shape == z_xr.shape + fn = xr_function([x], z) + res = fn(x_test) + expected_res = x_test.expand_dims(None) + xr_assert_allclose(res, expected_res) def test_squeeze(): @@ -293,12 +326,26 @@ def test_squeeze(): assert y.type.dims == ("city",) assert y.type.shape == (3,) + # Test with xarray + fn = xr_function([x], y) + x_test = xr_arange_like(x) + res = fn(x_test) + expected_res = x_test.squeeze("country") + xr_assert_allclose(res, expected_res) + # Test squeezing all dimensions of size 1 x2d = xtensor("x2d", dims=("row", "col", "batch"), shape=(2, 1, 1)) y2d = squeeze(x2d) assert y2d.type.dims == ("row",) assert y2d.type.shape == (2,) + # Test with xarray + fn = xr_function([x2d], y2d) + x2d_test = xr_arange_like(x2d) + res = fn(x2d_test) + expected_res = x2d_test.squeeze() + xr_assert_allclose(res, expected_res) + # Test that squeezing a non-existent dimension raises an error with pytest.raises(ValueError): squeeze(x, "time") From a076966b77c9f77249c522eb69c155cd32373ca9 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Fri, 30 May 2025 11:41:32 -0400 Subject: [PATCH 03/21] Simplify local_squeeze_reshape function by removing redundant checks and streamlining logic. --- pytensor/xtensor/rewriting/shape.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index 2b3dd31ec4..1578175aa9 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -2,7 +2,14 @@ from pytensor.tensor import broadcast_to, join, moveaxis, specify_shape from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.rewriting.basic import register_xcanonicalize -from pytensor.xtensor.shape import Concat, ExpandDims, Squeeze, Stack, Transpose, UnStack +from pytensor.xtensor.shape import ( + Concat, + ExpandDims, + Squeeze, + Stack, + Transpose, + UnStack, +) @register_xcanonicalize @@ -143,11 +150,7 @@ def local_squeeze_reshape(fgraph, node): # Get the index of the dimension to remove if dim is not None: - if dim not in x.type.dims: - return False dim_idx = x.type.dims.index(dim) - if x.type.shape[dim_idx] != 1: - return False else: # Find all dimensions of size 1 dim_idx = [i for i, s in enumerate(x.type.shape) if s == 1] From e2ffe1c7ccabb15740ec0816a1694c54fb606b78 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Fri, 30 May 2025 11:46:50 -0400 Subject: [PATCH 04/21] Ruff --- tests/xtensor/test_shape.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 1e4705c30f..e6948fb434 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -12,7 +12,14 @@ from xarray import DataArray from xarray import concat as xr_concat -from pytensor.xtensor.shape import concat, expand_dims, squeeze, stack, transpose, unstack +from pytensor.xtensor.shape import ( + concat, + expand_dims, + squeeze, + stack, + transpose, + unstack, +) from pytensor.xtensor.type import xtensor from tests.xtensor.util import ( xr_arange_like, From 74894899137ffb732c28fa22b0bdbd6c5b984f05 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Mon, 2 Jun 2025 09:49:09 -0400 Subject: [PATCH 05/21] Restoring the commit with xarray-based tests --- pytensor/xtensor/rewriting/shape.py | 63 +++++++++++++++++----------- pytensor/xtensor/shape.py | 64 +++++++++++++++-------------- tests/xtensor/test_shape.py | 62 ++++++++++++++++++++-------- 3 files changed, 117 insertions(+), 72 deletions(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index 9279635110..e1d3089a3f 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -1,8 +1,22 @@ from pytensor.graph import node_rewriter -from pytensor.tensor import broadcast_to, join, moveaxis, specify_shape, squeeze, expand_dims +from pytensor.tensor import ( + broadcast_to, + expand_dims, + join, + moveaxis, + specify_shape, + squeeze, +) from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.rewriting.basic import register_xcanonicalize -from pytensor.xtensor.shape import Concat, ExpandDims, Squeeze, Stack, Transpose, UnStack +from pytensor.xtensor.shape import ( + Concat, + ExpandDims, + Squeeze, + Stack, + Transpose, + UnStack, +) @register_xcanonicalize @@ -116,17 +130,17 @@ def local_expand_dims_reshape(fgraph, node): x = node.inputs[0] dim = node.op.dim - size = getattr(node.op, 'size', 1) + size = getattr(node.op, "size", 1) # If dim is None, don't add a new dimension (matching xarray behavior) if dim is None: return [x] # Create new dimensions list with the new dimension at the beginning - new_dims = [dim] + list(x.type.dims) + new_dims = [dim, *list(x.type.dims)] # Create new shape with the new dimension at the beginning - new_shape = [1] + list(x.type.shape) + new_shape = [1, *list(x.type.shape)] # Convert to tensor and use pytensor.tensor.expand_dims x_tensor = tensor_from_xtensor(x) @@ -147,32 +161,31 @@ def local_squeeze_reshape(fgraph, node): x = node.inputs[0] dim = node.op.dim - # Get the index of the dimension to remove - if dim is not None: - if dim not in x.type.dims: - return False - dim_idx = x.type.dims.index(dim) - if x.type.shape[dim_idx] != 1: - return False + # Convert single dimension to iterable for consistent handling + dims_to_remove = [dim] if isinstance(dim, str) else dim + + if dims_to_remove is not None: + # Validate dimensions exist and have size 1 + dim_indices = [] + for d in dims_to_remove: + if d not in x.type.dims: + return False + dim_idx = x.type.dims.index(d) + # Only check shape != 1 if the shape is not None (symbolic) + if x.type.shape[dim_idx] is not None and x.type.shape[dim_idx] != 1: + return False + dim_indices.append(dim_idx) else: # Find all dimensions of size 1 - dim_idx = [i for i, s in enumerate(x.type.shape) if s == 1] - if not dim_idx: + dim_indices = [i for i, s in enumerate(x.type.shape) if s == 1] + if not dim_indices: return False - # Create new dimensions and shape lists - new_dims = list(x.type.dims) - new_shape = list(x.type.shape) - if dim is not None: - new_dims.pop(dim_idx) - new_shape.pop(dim_idx) - else: - # Remove all dimensions of size 1 - new_dims = [d for i, d in enumerate(new_dims) if i not in dim_idx] - new_shape = [s for i, s in enumerate(new_shape) if i not in dim_idx] + # Create new dimensions list + new_dims = [d for i, d in enumerate(x.type.dims) if i not in dim_indices] # Convert to tensor and use pytensor.tensor.squeeze x_tensor = tensor_from_xtensor(x) - x_tensor_squeezed = squeeze(x_tensor, axis=dim_idx if dim is None else dim_idx) + x_tensor_squeezed = squeeze(x_tensor, axis=tuple(dim_indices)) new_out = xtensor_from_tensor(x_tensor_squeezed, dims=tuple(new_dims)) return [new_out] diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index c8babaf567..d7ee271dd4 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -313,6 +313,7 @@ class ExpandDims(XOp): size : int or symbolic, optional The size of the new dimension (default 1). """ + def __init__(self, dim, size=1): self.dim = dim self.size = size @@ -329,11 +330,13 @@ def make_node(self, x): raise ValueError(f"Dimension {self.dim} already exists") # Add new dimension at the beginning - new_dims = [self.dim] + list(x.type.dims) - new_shape = [self.size] + list(x.type.shape) + new_dims = [self.dim, *list(x.type.dims)] + new_shape = [self.size, *list(x.type.shape)] output = xtensor( - dtype=x.type.dtype, shape=tuple(new_shape), dims=tuple(new_dims) + dtype=x.type.dtype, + dims=tuple(new_dims), + shape=tuple(new_shape), ) return Apply(self, [x], [output]) @@ -357,12 +360,12 @@ def expand_dims(x, dim: str): class Squeeze(XOp): - """Remove a dimension of size 1 from an XTensorVariable. + """Remove dimensions of size 1 from an XTensorVariable. Parameters ---------- - dim : str or None - The name of the dimension to remove. If None, all dimensions of size 1 will be removed. + dim : str or None or iterable of str + The name(s) of the dimension(s) to remove. If None, all dimensions of size 1 will be removed. """ def __init__(self, dim=None): @@ -371,31 +374,32 @@ def __init__(self, dim=None): def make_node(self, x): x = as_xtensor(x) - # Get the index of the dimension to remove - if self.dim is not None: - if self.dim not in x.type.dims: - raise ValueError(f"Dimension {self.dim} not found") - dim_idx = x.type.dims.index(self.dim) - if x.type.shape[dim_idx] != 1: - raise ValueError( - f"Dimension {self.dim} has size {x.type.shape[dim_idx]}, not 1" - ) + # Convert single dimension to iterable for consistent handling + dims_to_remove = [self.dim] if isinstance(self.dim, str) else self.dim + + if dims_to_remove is not None: + # Validate dimensions exist and have size 1 + for dim in dims_to_remove: + if dim not in x.type.dims: + raise ValueError(f"Dimension {dim} not found") + dim_idx = x.type.dims.index(dim) + # Only raise an error if the shape is statically known and not 1. + # If the shape is None (symbolic), defer the error to runtime. + if x.type.shape[dim_idx] is not None and x.type.shape[dim_idx] != 1: + raise ValueError( + f"Dimension {dim} has size {x.type.shape[dim_idx]}, not 1" + ) + # Get indices of dimensions to remove + dim_indices = [x.type.dims.index(dim) for dim in dims_to_remove] else: # Find all dimensions of size 1 - dim_idx = [i for i, s in enumerate(x.type.shape) if s == 1] - if not dim_idx: + dim_indices = [i for i, s in enumerate(x.type.shape) if s == 1] + if not dim_indices: raise ValueError("No dimensions of size 1 to remove") # Create new dimensions and shape lists - new_dims = list(x.type.dims) - new_shape = list(x.type.shape) - if self.dim is not None: - new_dims.pop(dim_idx) - new_shape.pop(dim_idx) - else: - # Remove all dimensions of size 1 - new_dims = [d for i, d in enumerate(new_dims) if i not in dim_idx] - new_shape = [s for i, s in enumerate(new_shape) if i not in dim_idx] + new_dims = [d for i, d in enumerate(x.type.dims) if i not in dim_indices] + new_shape = [s for i, s in enumerate(x.type.shape) if i not in dim_indices] output = xtensor( dtype=x.type.dtype, shape=tuple(new_shape), dims=tuple(new_dims) @@ -404,18 +408,18 @@ def make_node(self, x): def squeeze(x, dim=None): - """Remove a dimension of size 1 from an XTensorVariable. + """Remove dimensions of size 1 from an XTensorVariable. Parameters ---------- x : XTensorVariable The input tensor - dim : str or None, optional - The name of the dimension to remove. If None, all dimensions of size 1 will be removed. + dim : str or None or iterable of str, optional + The name(s) of the dimension(s) to remove. If None, all dimensions of size 1 will be removed. Returns ------- XTensorVariable - A new tensor with the specified dimension removed + A new tensor with the specified dimension(s) removed """ return Squeeze(dim=dim)(x) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 04a3eaeeab..17516495c2 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -12,7 +12,14 @@ from xarray import DataArray from xarray import concat as xr_concat -from pytensor.xtensor.shape import concat, expand_dims, squeeze, stack, transpose, unstack +from pytensor.xtensor.shape import ( + concat, + expand_dims, + squeeze, + stack, + transpose, + unstack, +) from pytensor.xtensor.type import xtensor from tests.xtensor.util import ( xr_arange_like, @@ -260,6 +267,7 @@ def test_concat_scalar(): def test_expand_dims(): import xarray as xr + # 1D case x_xr = xr.DataArray([0, 1, 2], dims=["city"]) y_xr = x_xr.expand_dims("country") @@ -301,16 +309,8 @@ def test_expand_dims(): expand_dims(y, "city") # Expanding with None dimension - print("\nTesting expand_dims with None:") - print("Input xarray dims:", x_xr.dims) - print("Input xarray shape:", x_xr.shape) z_xr = x_xr.expand_dims(None) - print("Output xarray dims:", z_xr.dims) - print("Output xarray shape:", z_xr.shape) - print("Output xarray data:\n", z_xr.data) z = expand_dims(x, None) - print("Our output dims:", z.type.dims) - print("Our output shape:", z.type.shape) assert z.type.dims == z_xr.dims assert z.type.shape == z_xr.shape fn = xr_function([x], z) @@ -323,23 +323,51 @@ def test_squeeze(): # Test squeezing a specific dimension x = xtensor("x", dims=("city", "country"), shape=(3, 1)) y = squeeze(x, "country") - assert y.type.dims == ("city",) - assert y.type.shape == (3,) - - # Test with xarray fn = xr_function([x], y) x_test = xr_arange_like(x) res = fn(x_test) expected_res = x_test.squeeze("country") xr_assert_allclose(res, expected_res) + # Test squeezing multiple specific dimensions + x_multi = xtensor("x_multi", dims=("a", "b", "c", "d"), shape=(2, 1, 1, 3)) + y_multi = squeeze(x_multi, ["b", "c"]) + fn = xr_function([x_multi], y_multi) + x_multi_test = xr_arange_like(x_multi) + res = fn(x_multi_test) + expected_res = x_multi_test.squeeze(["b", "c"]) + xr_assert_allclose(res, expected_res) + + # Test squeezing a non-last dimension + x_nonlast = xtensor("x_nonlast", dims=("a", "b", "c"), shape=(2, 1, 3)) + y_nonlast = squeeze(x_nonlast, "b") + fn = xr_function([x_nonlast], y_nonlast) + x_nonlast_test = xr_arange_like(x_nonlast) + res = fn(x_nonlast_test) + expected_res = x_nonlast_test.squeeze("b") + xr_assert_allclose(res, expected_res) + + # Test squeezing in a higher-dimensional tensor + x_high = xtensor("x_high", dims=("a", "b", "c", "d", "e"), shape=(2, 1, 3, 1, 4)) + y_high = squeeze(x_high, ["b", "d"]) + fn = xr_function([x_high], y_high) + x_high_test = xr_arange_like(x_high) + res = fn(x_high_test) + expected_res = x_high_test.squeeze(["b", "d"]) + xr_assert_allclose(res, expected_res) + + # Test with symbolic shapes + x_sym = xtensor("x_sym", dims=("a", "b", "c")) + y_sym = squeeze(x_sym, "b") + x_sym_test = xr_arange_like(xtensor(dims=x_sym.dims, shape=(2, 1, 3))) + fn = xr_function([x_sym], y_sym) + res = fn(x_sym_test) + expected_res = x_sym_test.squeeze("b") + xr_assert_allclose(res, expected_res) + # Test squeezing all dimensions of size 1 x2d = xtensor("x2d", dims=("row", "col", "batch"), shape=(2, 1, 1)) y2d = squeeze(x2d) - assert y2d.type.dims == ("row",) - assert y2d.type.shape == (2,) - - # Test with xarray fn = xr_function([x2d], y2d) x2d_test = xr_arange_like(x2d) res = fn(x2d_test) From a7e2bf84946fe9a5239630b447e10b19f50ffb7b Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Mon, 2 Jun 2025 13:33:10 -0400 Subject: [PATCH 06/21] Updating squeeze --- pytensor/xtensor/rewriting/shape.py | 40 +++++++---------- pytensor/xtensor/shape.py | 61 ++++++++++++++++---------- tests/xtensor/test_shape.py | 68 +++++++++++++++++++++++++++++ 3 files changed, 122 insertions(+), 47 deletions(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index e1d3089a3f..8503c8b674 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -154,38 +154,28 @@ def local_expand_dims_reshape(fgraph, node): @register_xcanonicalize @node_rewriter([Squeeze]) def local_squeeze_reshape(fgraph, node): - """Rewrite rule to convert squeeze to pytensor.tensor.squeeze.""" + """Rewrite rule to convert Squeeze to pytensor.tensor.squeeze.""" if not isinstance(node.op, Squeeze): return False - x = node.inputs[0] + [x] = node.inputs + in_dims = x.type.dims dim = node.op.dim - # Convert single dimension to iterable for consistent handling - dims_to_remove = [dim] if isinstance(dim, str) else dim - - if dims_to_remove is not None: - # Validate dimensions exist and have size 1 - dim_indices = [] - for d in dims_to_remove: - if d not in x.type.dims: - return False - dim_idx = x.type.dims.index(d) - # Only check shape != 1 if the shape is not None (symbolic) - if x.type.shape[dim_idx] is not None and x.type.shape[dim_idx] != 1: - return False - dim_indices.append(dim_idx) + # Determine which axes to squeeze + if dim is None: + # Infer axes by comparing input and output dims + out_dims = node.outputs[0].type.dims + axes_to_squeeze = tuple(i for i, d in enumerate(in_dims) if d not in out_dims) else: - # Find all dimensions of size 1 - dim_indices = [i for i, s in enumerate(x.type.shape) if s == 1] - if not dim_indices: - return False + dims_to_remove = [dim] if isinstance(dim, str) else dim + axes_to_squeeze = tuple(in_dims.index(d) for d in dims_to_remove) - # Create new dimensions list - new_dims = [d for i, d in enumerate(x.type.dims) if i not in dim_indices] + # Nothing to squeeze? Just return input unchanged + if not axes_to_squeeze: + return [x] - # Convert to tensor and use pytensor.tensor.squeeze x_tensor = tensor_from_xtensor(x) - x_tensor_squeezed = squeeze(x_tensor, axis=tuple(dim_indices)) - new_out = xtensor_from_tensor(x_tensor_squeezed, dims=tuple(new_dims)) + x_tensor_squeezed = squeeze(x_tensor, axis=axes_to_squeeze) + new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims) return [new_out] diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index d7ee271dd4..81b240e553 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -364,45 +364,62 @@ class Squeeze(XOp): Parameters ---------- - dim : str or None or iterable of str - The name(s) of the dimension(s) to remove. If None, all dimensions of size 1 will be removed. + dim : str, None, or iterable of str + The name(s) of the dimension(s) to remove. If None, all dimensions + that are statically known to have size 1 will be removed. + Dimensions with symbolic shape will not be removed unless explicitly named. + + Note: Unlike NumPy/xarray, if dim is None, only dimensions known to + be size 1 at graph construction time will be removed, even if they happen + to be size 1 at runtime. """ + __props__ = ("dim",) + def __init__(self, dim=None): - self.dim = dim + if dim is None: + self.dim = None + else: + dims = [dim] if isinstance(dim, str) else dim + if not all(isinstance(d, str) for d in dims): + raise TypeError(f"All dimension names must be strings: got {dims}") + # Deduplicate and sort to make __props__ deterministic and hashable + self.dim = tuple(sorted(set(dims))) + + if not self.dim: + warnings.warn( + "Squeeze received an empty dim list — no dimensions will be removed." + ) def make_node(self, x): x = as_xtensor(x) - # Convert single dimension to iterable for consistent handling - dims_to_remove = [self.dim] if isinstance(self.dim, str) else self.dim + if self.dim is None: + # Auto-detect static size-1 dimensions + dims_to_remove = [d for d, s in zip(x.type.dims, x.type.shape) if s == 1] + if not dims_to_remove: + raise ValueError("No dimensions of size 1 to remove") + else: + dims_to_remove = list(self.dim) - if dims_to_remove is not None: - # Validate dimensions exist and have size 1 + # Validate existence and static shape (when possible) for dim in dims_to_remove: if dim not in x.type.dims: raise ValueError(f"Dimension {dim} not found") dim_idx = x.type.dims.index(dim) - # Only raise an error if the shape is statically known and not 1. - # If the shape is None (symbolic), defer the error to runtime. - if x.type.shape[dim_idx] is not None and x.type.shape[dim_idx] != 1: - raise ValueError( - f"Dimension {dim} has size {x.type.shape[dim_idx]}, not 1" - ) - # Get indices of dimensions to remove - dim_indices = [x.type.dims.index(dim) for dim in dims_to_remove] - else: - # Find all dimensions of size 1 - dim_indices = [i for i, s in enumerate(x.type.shape) if s == 1] - if not dim_indices: - raise ValueError("No dimensions of size 1 to remove") + shape = x.type.shape[dim_idx] + if shape is not None and shape != 1: + raise ValueError(f"Dimension {dim} has size {shape}, not 1") + + dim_indices = [x.type.dims.index(dim) for dim in dims_to_remove] - # Create new dimensions and shape lists new_dims = [d for i, d in enumerate(x.type.dims) if i not in dim_indices] new_shape = [s for i, s in enumerate(x.type.shape) if i not in dim_indices] output = xtensor( - dtype=x.type.dtype, shape=tuple(new_shape), dims=tuple(new_dims) + dtype=x.type.dtype, + shape=tuple(new_shape), + dims=tuple(new_dims), ) return Apply(self, [x], [output]) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 17516495c2..3ba9e7e182 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -386,3 +386,71 @@ def test_squeeze(): x3d = xtensor("x3d", dims=("row", "col", "batch"), shape=(2, 3, 4)) with pytest.raises(ValueError): squeeze(x3d) + + +def test_squeeze_additional_cases(): + # Redundant dimensions: squeeze(["b", "b"]) should behave like squeeze(["b"]) + x1 = xtensor("x1", dims=("a", "b", "c"), shape=(2, 1, 1)) + y1 = squeeze(x1, ["b", "b"]) + fn1 = xr_function([x1], y1) + x1_test = xr_arange_like(x1) + expected1 = x1_test.squeeze(["b"]) + xr_assert_allclose(fn1(x1_test), expected1) + + # Symbolic shape: dim is 1 at runtime → should squeeze successfully + x2 = xtensor("x2", dims=("a", "b", "c")) # shape unknown + y2 = squeeze(x2, "b") + fn2 = xr_function([x2], y2) + x2_test = xr_arange_like(xtensor(dims=x2.dims, shape=(2, 1, 3))) + expected2 = x2_test.squeeze("b") + xr_assert_allclose(fn2(x2_test), expected2) + + # Symbolic shape: dim is not 1 at runtime → should raise + x3 = xtensor("x3", dims=("a", "b", "c")) # shape unknown + y3 = squeeze(x3, "b") + fn3 = xr_function([x3], y3) + x3_test = xr_arange_like(xtensor(dims=x3.dims, shape=(2, 2, 3))) + with pytest.raises(Exception): + fn3(x3_test) + + # Reversibility: squeeze then expand_dims should restore original + # TODO: uncomment when we have expand_dims + # x4 = xtensor("x4", dims=("batch", "time", "feature"), shape=(2, 1, 3)) + # y4 = squeeze(x4, "time") + # z4 = expand_dims(y4, "time") + # fn4 = xr_function([x4], z4) + # x4_test = xr_arange_like(x4) + # xr_assert_allclose(fn4(x4_test), x4_test) + + +def test_squeeze_extra_cases(): + # 1. Order of dims shouldn't affect result + x1 = xtensor("x1", dims=("a", "b", "c"), shape=(2, 1, 1)) + y1 = squeeze(x1, ["b", "c"]) + y2 = squeeze(x1, ["c", "b"]) + fn1 = xr_function([x1], y1) + fn2 = xr_function([x1], y2) + x1_test = xr_arange_like(x1) + xr_assert_allclose(fn1(x1_test), fn2(x1_test)) + + # 2. Empty list of dims = no-op + x2 = xtensor("x2", dims=("a", "b", "c"), shape=(2, 1, 1)) + y2 = squeeze(x2, []) + fn2 = xr_function([x2], y2) + x2_test = xr_arange_like(x2) + xr_assert_allclose(fn2(x2_test), x2_test) + + # 3. Explicit squeeze of all size-1 dims via dim=None + x3 = xtensor("x3", dims=("a", "b"), shape=(1, 1)) + y3 = squeeze(x3) + fn3 = xr_function([x3], y3) + x3_test = xr_arange_like(x3) + xr_assert_allclose(fn3(x3_test), x3_test.squeeze()) + + # 4. Static + symbolic shape mix: squeeze symbolic 1-sized dim + x4 = xtensor("x4", dims=("a", "b", "c"), shape=(None, 1, 3)) + y4 = squeeze(x4, "b") + x4_test = xr_arange_like(xtensor(dims=x4.dims, shape=(4, 1, 3))) + fn4 = xr_function([x4], y4) + expected4 = x4_test.squeeze("b") + xr_assert_allclose(fn4(x4_test), expected4) From 332139dfc68e231523dcc21ab3a16f8a2748312d Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Mon, 2 Jun 2025 15:19:16 -0400 Subject: [PATCH 07/21] Working on expand_dims --- pytensor/xtensor/rewriting/shape.py | 26 ++--- pytensor/xtensor/shape.py | 49 ++++++--- tests/xtensor/test_shape.py | 149 +++++++++++++++++++++------- 3 files changed, 161 insertions(+), 63 deletions(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index 8503c8b674..fc7b75f6c5 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -1,3 +1,5 @@ +import numpy as np + from pytensor.graph import node_rewriter from pytensor.tensor import ( broadcast_to, @@ -124,7 +126,7 @@ def lower_transpose(fgraph, node): @register_xcanonicalize @node_rewriter([ExpandDims]) def local_expand_dims_reshape(fgraph, node): - """Rewrite rule to convert expand_dims to pytensor.tensor.expand_dims and broadcast_to if needed.""" + """Rewrite ExpandDims to tensor.expand_dims and optionally broadcast_to or specify_shape.""" if not isinstance(node.op, ExpandDims): return False @@ -132,22 +134,22 @@ def local_expand_dims_reshape(fgraph, node): dim = node.op.dim size = getattr(node.op, "size", 1) - # If dim is None, don't add a new dimension (matching xarray behavior) if dim is None: return [x] - # Create new dimensions list with the new dimension at the beginning - new_dims = [dim, *list(x.type.dims)] - - # Create new shape with the new dimension at the beginning - new_shape = [1, *list(x.type.shape)] - - # Convert to tensor and use pytensor.tensor.expand_dims x_tensor = tensor_from_xtensor(x) x_tensor_expanded = expand_dims(x_tensor, axis=0) - if size != 1: - x_tensor_expanded = broadcast_to(x_tensor_expanded, new_shape) - new_out = xtensor_from_tensor(x_tensor_expanded, dims=tuple(new_dims)) + + target_shape = node.outputs[0].type.shape + + if isinstance(size, int | np.integer): + if size != 1 and None not in target_shape: + x_tensor_expanded = broadcast_to(x_tensor_expanded, target_shape) + else: + # Symbolic size: enforce shape so broadcast happens downstream correctly + x_tensor_expanded = specify_shape(x_tensor_expanded, target_shape) + + new_out = xtensor_from_tensor(x_tensor_expanded, dims=node.outputs[0].type.dims) return [new_out] diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 81b240e553..5fc9ea463e 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -2,6 +2,8 @@ from collections.abc import Sequence from typing import Literal +import numpy as np + from pytensor import Variable from pytensor.graph import Apply from pytensor.scalar import discrete_dtypes, upcast @@ -314,49 +316,70 @@ class ExpandDims(XOp): The size of the new dimension (default 1). """ + __props__ = ("dim", "size") + def __init__(self, dim, size=1): + if dim is not None and not isinstance(dim, str): + raise TypeError(f"`dim` must be a string or None, got: {type(dim)}") + if isinstance(size, int | np.integer) and size <= 0: + raise ValueError(f"size must be positive, got: {size}") self.dim = dim self.size = size def make_node(self, x): x = as_xtensor(x) - # If dim is None, don't add a new dimension (matching xarray behavior) if self.dim is None: return Apply(self, [x], [x]) - # Check if dimension already exists if self.dim in x.type.dims: - raise ValueError(f"Dimension {self.dim} already exists") + raise ValueError(f"Dimension {self.dim} already exists in {x.type.dims}") - # Add new dimension at the beginning - new_dims = [self.dim, *list(x.type.dims)] - new_shape = [self.size, *list(x.type.shape)] + # Handle scalar case + if not x.type.dims: + new_dims = (self.dim,) + new_shape = (self.size,) + else: + # Use symbolic shape + new_dims = (self.dim, *x.type.dims) + if isinstance(self.size, int | np.integer): + new_shape = (self.size, *x.type.shape) + else: + # For symbolic size, we need to use a symbolic shape + new_shape = (None, *x.type.shape) - output = xtensor( + out = xtensor( dtype=x.type.dtype, - dims=tuple(new_dims), - shape=tuple(new_shape), + shape=new_shape, + dims=new_dims, ) - return Apply(self, [x], [output]) + return Apply(self, [x], [out]) + + def infer_shape(self, fgraph, node, input_shapes): + (input_shape,) = input_shapes + if self.dim is None: + return [input_shape] + return [(self.size, *list(input_shape))] -def expand_dims(x, dim: str): +def expand_dims(x, dim: str | None, size=1): """Add a new dimension to an XTensorVariable. Parameters ---------- x : XTensorVariable The input tensor - dim : str + dim : str or None The name of the new dimension + size : int or symbolic, optional + The size of the new dimension (default 1) Returns ------- XTensorVariable A new tensor with the expanded dimension """ - return ExpandDims(dim=dim)(x) + return ExpandDims(dim=dim, size=size)(x) class Squeeze(XOp): diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 3ba9e7e182..67b558f010 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -1,17 +1,14 @@ # ruff: noqa: E402 import re - -import pytest - - -pytest.importorskip("xarray") - from itertools import chain, combinations import numpy as np +import pytest +import xarray as xr from xarray import DataArray from xarray import concat as xr_concat +from pytensor.tensor import scalar from pytensor.xtensor.shape import ( concat, expand_dims, @@ -29,6 +26,9 @@ ) +pytest.importorskip("xarray") + + def powerset(iterable, min_group_size=0): "Subsequences of the iterable from shortest to longest." # powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3) @@ -265,58 +265,131 @@ def test_concat_scalar(): xr_assert_allclose(res, expected_res) -def test_expand_dims(): - import xarray as xr +def assert_dims_and_shape(actual, expected): + assert actual.type.dims == expected.dims + assert actual.type.shape == expected.shape + +def test_expand_dims(): # 1D case x_xr = xr.DataArray([0, 1, 2], dims=["city"]) y_xr = x_xr.expand_dims("country") x = xtensor("x", dims=("city",), shape=(3,)) y = expand_dims(x, "country") - assert y.type.dims == y_xr.dims - assert y.type.shape == y_xr.shape + assert_dims_and_shape(y, y_xr) fn = xr_function([x], y) x_test = xr_arange_like(x) - res = fn(x_test) - expected_res = x_test.expand_dims("country") - xr_assert_allclose(res, expected_res) + xr_assert_allclose(fn(x_test), y_xr) # 2D case x2d_xr = xr.DataArray([[0, 1, 2], [3, 4, 5]], dims=["row", "col"]) y2d_xr = x2d_xr.expand_dims("batch") x2d = xtensor("x2d", dims=("row", "col"), shape=(2, 3)) y2d = expand_dims(x2d, "batch") - assert y2d.type.dims == y2d_xr.dims - assert y2d.type.shape == y2d_xr.shape + assert_dims_and_shape(y2d, y2d_xr) fn = xr_function([x2d], y2d) x2d_test = xr_arange_like(x2d) - res = fn(x2d_test) - expected_res = x2d_test.expand_dims("batch") - xr_assert_allclose(res, expected_res) + xr_assert_allclose(fn(x2d_test), y2d_xr) # Expansion with different dimension name z_xr = x_xr.expand_dims("time") z = expand_dims(x, "time") - assert z.type.dims == z_xr.dims - assert z.type.shape == z_xr.shape + assert_dims_and_shape(z, z_xr) fn = xr_function([x], z) - res = fn(x_test) - expected_res = x_test.expand_dims("time") - xr_assert_allclose(res, expected_res) + xr_assert_allclose(fn(x_test), z_xr) # Expanding with an existing dimension raises an error - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="already exists"): expand_dims(y, "city") - # Expanding with None dimension - z_xr = x_xr.expand_dims(None) + # Expanding with None dimension should return the same variable (no-op) z = expand_dims(x, None) - assert z.type.dims == z_xr.dims - assert z.type.shape == z_xr.shape + assert z is x + + # Test prepending different dimension names + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + for new_dim in ("x", "y", "z"): + y = expand_dims(x, new_dim) + expected_dims = (new_dim, *x.type.dims) + expected_shape = (1, *x.type.shape) + assert y.type.dims == expected_dims + assert y.type.shape == expected_shape + + # Explicit size=1 behaves like implicit broadcast + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + x_test = xr_arange_like(x) + y1 = expand_dims(x, "batch", size=1) + y2 = expand_dims(x, "batch") + fn1 = xr_function([x], y1) + fn2 = xr_function([x], y2) + xr_assert_allclose(fn1(x_test), fn2(x_test)) + + # Expanding with size=0 raises + with pytest.raises(ValueError, match="size must be.*positive"): + expand_dims(x, "batch", size=0) + + +def test_expand_dims_additional_cases(): + # Expanding a scalar + x = xtensor("x", dims=(), shape=()) + y = expand_dims(x, "batch") + assert y.type.dims == ("batch",) + assert y.type.shape == (1,) + fn = xr_function([x], y) + x_test = xr_arange_like(x) + expected = x_test.expand_dims("batch") + xr_assert_allclose(fn(x_test), expected) + + # Expanding with a specified static size > 1 + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = expand_dims(x, "batch", size=4) + assert y.type.dims == ("batch", "a", "b") + assert y.type.shape == (4, 2, 3) + fn = xr_function([x], y) + x_test = xr_arange_like(x) + expected = xr.DataArray( + np.broadcast_to(x_test.data, (4, 2, 3)), + dims=("batch", "a", "b"), + coords={"a": x_test.coords["a"], "b": x_test.coords["b"]}, + ) + xr_assert_allclose(fn(x_test), expected) + + # Expanding with symbolic size = 1 (no broadcast) + size_sym_1 = scalar("size_sym_1", dtype="int64") + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = expand_dims(x, "batch", size=size_sym_1) + fn = xr_function([x, size_sym_1], y, on_unused_input="ignore") + x_test = xr_arange_like(x) + expected = x_test.expand_dims("batch") + xr_assert_allclose(fn(x_test, 1), expected) + + # Expanding with symbolic size > 1 (but no broadcast expected) + size_sym_4 = scalar("size_sym_4", dtype="int64") + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = expand_dims(x, "batch", size=size_sym_4) + fn = xr_function([x, size_sym_4], y, on_unused_input="ignore") + x_test = xr_arange_like(x) + + # Even if symbolic size is 4, expand_dims will only insert dim=1 + expected = x_test.expand_dims("batch") + xr_assert_allclose(fn(x_test, 4), expected) + + # Reversibility: expand_dims then squeeze + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = expand_dims(x, "batch") + z = squeeze(y, "batch") fn = xr_function([x], z) - res = fn(x_test) - expected_res = x_test.expand_dims(None) - xr_assert_allclose(res, expected_res) + x_test = xr_arange_like(x) + xr_assert_allclose(fn(x_test), x_test) + + # Expand with dim=None is a no-op + x = xtensor("x", dims=("a",), shape=(3,)) + y = expand_dims(x, None) + assert y.type.dims == x.type.dims + assert y.type.shape == x.type.shape + fn = xr_function([x], y) + x_test = xr_arange_like(x) + xr_assert_allclose(fn(x_test), x_test) def test_squeeze(): @@ -414,13 +487,13 @@ def test_squeeze_additional_cases(): fn3(x3_test) # Reversibility: squeeze then expand_dims should restore original - # TODO: uncomment when we have expand_dims - # x4 = xtensor("x4", dims=("batch", "time", "feature"), shape=(2, 1, 3)) - # y4 = squeeze(x4, "time") - # z4 = expand_dims(y4, "time") - # fn4 = xr_function([x4], z4) - # x4_test = xr_arange_like(x4) - # xr_assert_allclose(fn4(x4_test), x4_test) + x4 = xtensor("x4", dims=("batch", "time", "feature"), shape=(2, 1, 3)) + y4 = squeeze(x4, "time") + z4 = expand_dims(y4, "time") + fn4 = xr_function([x4], z4) + x4_test = xr_arange_like(x4) + # Adjust dimension order for comparison + xr_assert_allclose(fn4(x4_test).transpose(*x4_test.dims), x4_test) def test_squeeze_extra_cases(): From 4b2f0f7d128d10a4e929f0ebf261913d3b77daf7 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Tue, 3 Jun 2025 16:57:25 -0400 Subject: [PATCH 08/21] Working on squeeze --- pytensor/xtensor/rewriting/shape.py | 4 +- pytensor/xtensor/shape.py | 117 +++++++------- tests/xtensor/test_shape.py | 238 ++++++++++++---------------- 3 files changed, 160 insertions(+), 199 deletions(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index fc7b75f6c5..bc374046cd 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -131,7 +131,7 @@ def local_expand_dims_reshape(fgraph, node): return False x = node.inputs[0] - dim = node.op.dim + dim = node.op.dims size = getattr(node.op, "size", 1) if dim is None: @@ -162,7 +162,7 @@ def local_squeeze_reshape(fgraph, node): [x] = node.inputs in_dims = x.type.dims - dim = node.op.dim + dim = node.op.dims # Determine which axes to squeeze if dim is None: diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 5fc9ea463e..70ca5e82c6 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -316,32 +316,32 @@ class ExpandDims(XOp): The size of the new dimension (default 1). """ - __props__ = ("dim", "size") + __props__ = ("dims", "size") def __init__(self, dim, size=1): if dim is not None and not isinstance(dim, str): raise TypeError(f"`dim` must be a string or None, got: {type(dim)}") if isinstance(size, int | np.integer) and size <= 0: raise ValueError(f"size must be positive, got: {size}") - self.dim = dim + self.dims = dim self.size = size def make_node(self, x): x = as_xtensor(x) - if self.dim is None: + if self.dims is None: return Apply(self, [x], [x]) - if self.dim in x.type.dims: - raise ValueError(f"Dimension {self.dim} already exists in {x.type.dims}") + if self.dims in x.type.dims: + raise ValueError(f"Dimension {self.dims} already exists in {x.type.dims}") # Handle scalar case if not x.type.dims: - new_dims = (self.dim,) + new_dims = (self.dims,) new_shape = (self.size,) else: # Use symbolic shape - new_dims = (self.dim, *x.type.dims) + new_dims = (self.dims, *x.type.dims) if isinstance(self.size, int | np.integer): new_shape = (self.size, *x.type.shape) else: @@ -357,7 +357,7 @@ def make_node(self, x): def infer_shape(self, fgraph, node, input_shapes): (input_shape,) = input_shapes - if self.dim is None: + if self.dims is None: return [input_shape] return [(self.size, *list(input_shape))] @@ -383,68 +383,49 @@ def expand_dims(x, dim: str | None, size=1): class Squeeze(XOp): - """Remove dimensions of size 1 from an XTensorVariable. + """Remove specified dimensions from an XTensorVariable. + + Only dimensions that are known statically to be size 1 will be removed. + Symbolic dimensions must be explicitly specified, and are assumed safe. Parameters ---------- - dim : str, None, or iterable of str - The name(s) of the dimension(s) to remove. If None, all dimensions - that are statically known to have size 1 will be removed. - Dimensions with symbolic shape will not be removed unless explicitly named. - - Note: Unlike NumPy/xarray, if dim is None, only dimensions known to - be size 1 at graph construction time will be removed, even if they happen - to be size 1 at runtime. + dim : tuple of str + The names of the dimensions to remove. """ - __props__ = ("dim",) + __props__ = ("dims",) - def __init__(self, dim=None): - if dim is None: - self.dim = None - else: - dims = [dim] if isinstance(dim, str) else dim - if not all(isinstance(d, str) for d in dims): - raise TypeError(f"All dimension names must be strings: got {dims}") - # Deduplicate and sort to make __props__ deterministic and hashable - self.dim = tuple(sorted(set(dims))) - - if not self.dim: - warnings.warn( - "Squeeze received an empty dim list — no dimensions will be removed." - ) + def __init__(self, dim): + self.dims = dim def make_node(self, x): x = as_xtensor(x) - if self.dim is None: - # Auto-detect static size-1 dimensions - dims_to_remove = [d for d, s in zip(x.type.dims, x.type.shape) if s == 1] - if not dims_to_remove: - raise ValueError("No dimensions of size 1 to remove") - else: - dims_to_remove = list(self.dim) - - # Validate existence and static shape (when possible) - for dim in dims_to_remove: - if dim not in x.type.dims: - raise ValueError(f"Dimension {dim} not found") - dim_idx = x.type.dims.index(dim) - shape = x.type.shape[dim_idx] - if shape is not None and shape != 1: - raise ValueError(f"Dimension {dim} has size {shape}, not 1") - - dim_indices = [x.type.dims.index(dim) for dim in dims_to_remove] - - new_dims = [d for i, d in enumerate(x.type.dims) if i not in dim_indices] - new_shape = [s for i, s in enumerate(x.type.shape) if i not in dim_indices] + # Validate that dims exist and are size-1 if statically known + dims_to_remove = [] + for d in self.dims: + if d not in x.type.dims: + raise ValueError(f"Dimension {d} not found in {x.type.dims}") + idx = x.type.dims.index(d) + dim_size = x.type.shape[idx] + if dim_size is not None and dim_size != 1: + raise ValueError(f"Dimension {d} has static size {dim_size}, not 1") + dims_to_remove.append(idx) + + new_dims = tuple( + d for i, d in enumerate(x.type.dims) if i not in dims_to_remove + ) + new_shape = tuple( + s for i, s in enumerate(x.type.shape) if i not in dims_to_remove + ) - output = xtensor( + out = xtensor( dtype=x.type.dtype, - shape=tuple(new_shape), - dims=tuple(new_dims), + shape=new_shape, + dims=new_dims, ) - return Apply(self, [x], [output]) + return Apply(self, [x], [out]) def squeeze(x, dim=None): @@ -455,11 +436,27 @@ def squeeze(x, dim=None): x : XTensorVariable The input tensor dim : str or None or iterable of str, optional - The name(s) of the dimension(s) to remove. If None, all dimensions of size 1 will be removed. + The name(s) of the dimension(s) to remove. If None, all dimensions of size 1 + (known statically) will be removed. Dimensions with symbolic shape will be retained. Returns ------- XTensorVariable - A new tensor with the specified dimension(s) removed + A new tensor with the specified dimension(s) removed. """ - return Squeeze(dim=dim)(x) + x = as_xtensor(x) + + if dim is None: + dims = tuple(d for d, s in zip(x.type.dims, x.type.shape) if s == 1) + elif isinstance(dim, str): + dims = (dim,) + else: + dims = tuple(dim) + + # Normalize: deduplicate and sort + dims = tuple(sorted(set(dims))) + + if not dims: + return x # no-op if nothing to squeeze + + return Squeeze(dim=dims)(x) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 67b558f010..d2d24cee87 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -265,38 +265,42 @@ def test_concat_scalar(): xr_assert_allclose(res, expected_res) -def assert_dims_and_shape(actual, expected): - assert actual.type.dims == expected.dims - assert actual.type.shape == expected.shape - - def test_expand_dims(): # 1D case x_xr = xr.DataArray([0, 1, 2], dims=["city"]) y_xr = x_xr.expand_dims("country") x = xtensor("x", dims=("city",), shape=(3,)) y = expand_dims(x, "country") - assert_dims_and_shape(y, y_xr) fn = xr_function([x], y) - x_test = xr_arange_like(x) - xr_assert_allclose(fn(x_test), y_xr) + xr_assert_allclose(fn(x_xr), y_xr) # 2D case - x2d_xr = xr.DataArray([[0, 1, 2], [3, 4, 5]], dims=["row", "col"]) - y2d_xr = x2d_xr.expand_dims("batch") - x2d = xtensor("x2d", dims=("row", "col"), shape=(2, 3)) - y2d = expand_dims(x2d, "batch") - assert_dims_and_shape(y2d, y2d_xr) - fn = xr_function([x2d], y2d) - x2d_test = xr_arange_like(x2d) - xr_assert_allclose(fn(x2d_test), y2d_xr) + x_xr = xr.DataArray([[0, 1], [2, 3]], dims=["city", "year"]) + y_xr = x_xr.expand_dims("country") + x = xtensor("x", dims=("city", "year"), shape=(2, 2)) + y = expand_dims(x, "country") + fn = xr_function([x], y) + xr_assert_allclose(fn(x_xr), y_xr) - # Expansion with different dimension name - z_xr = x_xr.expand_dims("time") - z = expand_dims(x, "time") - assert_dims_and_shape(z, z_xr) + # 3D case + x_xr = xr.DataArray( + [[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dims=["city", "year", "month"] + ) + y_xr = x_xr.expand_dims("country") + x = xtensor("x", dims=("city", "year", "month"), shape=(2, 2, 2)) + y = expand_dims(x, "country") + fn = xr_function([x], y) + xr_assert_allclose(fn(x_xr), y_xr) + + # Test that expand_dims is reversible with squeeze + x_xr = xr.DataArray([0, 1, 2], dims=["city"]) + y_xr = x_xr.expand_dims("country") + z_xr = y_xr.squeeze("country") + x = xtensor("x", dims=("city",), shape=(3,)) + y = expand_dims(x, "country") + z = squeeze(y, "country") fn = xr_function([x], z) - xr_assert_allclose(fn(x_test), z_xr) + xr_assert_allclose(fn(x_xr), z_xr) # Expanding with an existing dimension raises an error with pytest.raises(ValueError, match="already exists"): @@ -328,8 +332,6 @@ def test_expand_dims(): with pytest.raises(ValueError, match="size must be.*positive"): expand_dims(x, "batch", size=0) - -def test_expand_dims_additional_cases(): # Expanding a scalar x = xtensor("x", dims=(), shape=()) y = expand_dims(x, "batch") @@ -393,137 +395,99 @@ def test_expand_dims_additional_cases(): def test_squeeze(): - # Test squeezing a specific dimension + # Basic squeeze x = xtensor("x", dims=("city", "country"), shape=(3, 1)) y = squeeze(x, "country") fn = xr_function([x], y) x_test = xr_arange_like(x) - res = fn(x_test) - expected_res = x_test.squeeze("country") - xr_assert_allclose(res, expected_res) + expected = x_test.squeeze("country") + xr_assert_allclose(fn(x_test), expected) - # Test squeezing multiple specific dimensions + # Multiple dims x_multi = xtensor("x_multi", dims=("a", "b", "c", "d"), shape=(2, 1, 1, 3)) y_multi = squeeze(x_multi, ["b", "c"]) fn = xr_function([x_multi], y_multi) - x_multi_test = xr_arange_like(x_multi) - res = fn(x_multi_test) - expected_res = x_multi_test.squeeze(["b", "c"]) - xr_assert_allclose(res, expected_res) + x_test = xr_arange_like(x_multi) + xr_assert_allclose(fn(x_test), x_test.squeeze(["b", "c"])) - # Test squeezing a non-last dimension - x_nonlast = xtensor("x_nonlast", dims=("a", "b", "c"), shape=(2, 1, 3)) - y_nonlast = squeeze(x_nonlast, "b") - fn = xr_function([x_nonlast], y_nonlast) - x_nonlast_test = xr_arange_like(x_nonlast) - res = fn(x_nonlast_test) - expected_res = x_nonlast_test.squeeze("b") - xr_assert_allclose(res, expected_res) + # All dims size 1 + x2d = xtensor("x2d", dims=("row", "col", "batch"), shape=(2, 1, 1)) + y2d = squeeze(x2d) + fn = xr_function([x2d], y2d) + x_test = xr_arange_like(x2d) + xr_assert_allclose(fn(x_test), x_test.squeeze()) - # Test squeezing in a higher-dimensional tensor - x_high = xtensor("x_high", dims=("a", "b", "c", "d", "e"), shape=(2, 1, 3, 1, 4)) - y_high = squeeze(x_high, ["b", "d"]) - fn = xr_function([x_high], y_high) - x_high_test = xr_arange_like(x_high) - res = fn(x_high_test) - expected_res = x_high_test.squeeze(["b", "d"]) - xr_assert_allclose(res, expected_res) + # Redundant dims + x = xtensor("x", dims=("a", "b", "c"), shape=(2, 1, 1)) + y = squeeze(x, ["b", "b"]) + fn = xr_function([x], y) + x_test = xr_arange_like(x) + xr_assert_allclose(fn(x_test), x_test.squeeze("b")) - # Test with symbolic shapes + # Order shouldn't matter + y1 = squeeze(x, ["b", "c"]) + y2 = squeeze(x, ["c", "b"]) + fn1 = xr_function([x], y1) + fn2 = xr_function([x], y2) + x_test = xr_arange_like(x) + xr_assert_allclose(fn1(x_test), fn2(x_test)) + + # Empty dims = no-op + y = squeeze(x, []) + fn = xr_function([x], y) + xr_assert_allclose(fn(x_test), x_test) + + # Squeeze all size-1 dims with dim=None + x = xtensor("x", dims=("a", "b"), shape=(1, 1)) + y = squeeze(x) + fn = xr_function([x], y) + x_test = xr_arange_like(x) + xr_assert_allclose(fn(x_test), x_test.squeeze()) + + # Symbolic dims: squeeze static size 1 at runtime x_sym = xtensor("x_sym", dims=("a", "b", "c")) y_sym = squeeze(x_sym, "b") - x_sym_test = xr_arange_like(xtensor(dims=x_sym.dims, shape=(2, 1, 3))) + x_test = xr_arange_like(xtensor(dims=x_sym.dims, shape=(2, 1, 3))) fn = xr_function([x_sym], y_sym) - res = fn(x_sym_test) - expected_res = x_sym_test.squeeze("b") - xr_assert_allclose(res, expected_res) + xr_assert_allclose(fn(x_test), x_test.squeeze("b")) - # Test squeezing all dimensions of size 1 - x2d = xtensor("x2d", dims=("row", "col", "batch"), shape=(2, 1, 1)) - y2d = squeeze(x2d) - fn = xr_function([x2d], y2d) - x2d_test = xr_arange_like(x2d) - res = fn(x2d_test) - expected_res = x2d_test.squeeze() - xr_assert_allclose(res, expected_res) + # Symbolic 1-size dim + known dim + x = xtensor("x", dims=("a", "b", "c"), shape=(None, 1, 3)) + y = squeeze(x, "b") + x_test = xr_arange_like(xtensor(dims=x.dims, shape=(4, 1, 3))) + fn = xr_function([x], y) + xr_assert_allclose(fn(x_test), x_test.squeeze("b")) - # Test that squeezing a non-existent dimension raises an error - with pytest.raises(ValueError): + # Squeeze then expand_dims → reversible + x = xtensor("x", dims=("batch", "time", "feature"), shape=(2, 1, 3)) + y = squeeze(x, "time") + z = expand_dims(y, "time") + fn = xr_function([x], z) + x_test = xr_arange_like(x) + xr_assert_allclose(fn(x_test).transpose(*x_test.dims), x_test) + + # No dims to squeeze → no-op + x = xtensor("x", dims=("row", "col", "batch"), shape=(2, 3, 4)) + y = squeeze(x) + fn = xr_function([x], y) + x_test = xr_arange_like(x) + xr_assert_allclose(fn(x_test), x_test) + + +def test_squeeze_errors(): + # Squeeze nonexistent dim + x = xtensor("x", dims=("city", "country"), shape=(3, 1)) + with pytest.raises(ValueError, match="Dimension .* not found"): squeeze(x, "time") - # Test that squeezing a dimension of size > 1 raises an error - with pytest.raises(ValueError): + # Squeeze non-size-1 dim + with pytest.raises(ValueError, match="has static size .* not 1"): squeeze(x, "city") - # Test that squeezing when no dimensions are of size 1 raises an error - x3d = xtensor("x3d", dims=("row", "col", "batch"), shape=(2, 3, 4)) - with pytest.raises(ValueError): - squeeze(x3d) - - -def test_squeeze_additional_cases(): - # Redundant dimensions: squeeze(["b", "b"]) should behave like squeeze(["b"]) - x1 = xtensor("x1", dims=("a", "b", "c"), shape=(2, 1, 1)) - y1 = squeeze(x1, ["b", "b"]) - fn1 = xr_function([x1], y1) - x1_test = xr_arange_like(x1) - expected1 = x1_test.squeeze(["b"]) - xr_assert_allclose(fn1(x1_test), expected1) - - # Symbolic shape: dim is 1 at runtime → should squeeze successfully - x2 = xtensor("x2", dims=("a", "b", "c")) # shape unknown - y2 = squeeze(x2, "b") - fn2 = xr_function([x2], y2) - x2_test = xr_arange_like(xtensor(dims=x2.dims, shape=(2, 1, 3))) - expected2 = x2_test.squeeze("b") - xr_assert_allclose(fn2(x2_test), expected2) - - # Symbolic shape: dim is not 1 at runtime → should raise - x3 = xtensor("x3", dims=("a", "b", "c")) # shape unknown - y3 = squeeze(x3, "b") - fn3 = xr_function([x3], y3) - x3_test = xr_arange_like(xtensor(dims=x3.dims, shape=(2, 2, 3))) - with pytest.raises(Exception): - fn3(x3_test) - - # Reversibility: squeeze then expand_dims should restore original - x4 = xtensor("x4", dims=("batch", "time", "feature"), shape=(2, 1, 3)) - y4 = squeeze(x4, "time") - z4 = expand_dims(y4, "time") - fn4 = xr_function([x4], z4) - x4_test = xr_arange_like(x4) - # Adjust dimension order for comparison - xr_assert_allclose(fn4(x4_test).transpose(*x4_test.dims), x4_test) - - -def test_squeeze_extra_cases(): - # 1. Order of dims shouldn't affect result - x1 = xtensor("x1", dims=("a", "b", "c"), shape=(2, 1, 1)) - y1 = squeeze(x1, ["b", "c"]) - y2 = squeeze(x1, ["c", "b"]) - fn1 = xr_function([x1], y1) - fn2 = xr_function([x1], y2) - x1_test = xr_arange_like(x1) - xr_assert_allclose(fn1(x1_test), fn2(x1_test)) - - # 2. Empty list of dims = no-op - x2 = xtensor("x2", dims=("a", "b", "c"), shape=(2, 1, 1)) - y2 = squeeze(x2, []) - fn2 = xr_function([x2], y2) - x2_test = xr_arange_like(x2) - xr_assert_allclose(fn2(x2_test), x2_test) - - # 3. Explicit squeeze of all size-1 dims via dim=None - x3 = xtensor("x3", dims=("a", "b"), shape=(1, 1)) - y3 = squeeze(x3) - fn3 = xr_function([x3], y3) - x3_test = xr_arange_like(x3) - xr_assert_allclose(fn3(x3_test), x3_test.squeeze()) - - # 4. Static + symbolic shape mix: squeeze symbolic 1-sized dim - x4 = xtensor("x4", dims=("a", "b", "c"), shape=(None, 1, 3)) - y4 = squeeze(x4, "b") - x4_test = xr_arange_like(xtensor(dims=x4.dims, shape=(4, 1, 3))) - fn4 = xr_function([x4], y4) - expected4 = x4_test.squeeze("b") - xr_assert_allclose(fn4(x4_test), expected4) + # Symbolic dim is not size 1 at runtime + x = xtensor("x", dims=("a", "b", "c")) + y = squeeze(x, "b") + x_test = xr_arange_like(xtensor(dims=x.dims, shape=(2, 2, 3))) + fn = xr_function([x], y) + with pytest.raises(Exception): # shape assertion fails at runtime + fn(x_test) From 2120b1a5866d5d88383b18d431a4402536342d32 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Wed, 4 Jun 2025 14:25:51 -0400 Subject: [PATCH 09/21] Organizing squeeze tests --- tests/xtensor/test_shape.py | 181 ++++++++++++++++++------------------ 1 file changed, 90 insertions(+), 91 deletions(-) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index d2d24cee87..28cc6f28c4 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -394,100 +394,99 @@ def test_expand_dims(): xr_assert_allclose(fn(x_test), x_test) -def test_squeeze(): - # Basic squeeze - x = xtensor("x", dims=("city", "country"), shape=(3, 1)) - y = squeeze(x, "country") - fn = xr_function([x], y) - x_test = xr_arange_like(x) - expected = x_test.squeeze("country") - xr_assert_allclose(fn(x_test), expected) - - # Multiple dims - x_multi = xtensor("x_multi", dims=("a", "b", "c", "d"), shape=(2, 1, 1, 3)) - y_multi = squeeze(x_multi, ["b", "c"]) - fn = xr_function([x_multi], y_multi) - x_test = xr_arange_like(x_multi) - xr_assert_allclose(fn(x_test), x_test.squeeze(["b", "c"])) - - # All dims size 1 - x2d = xtensor("x2d", dims=("row", "col", "batch"), shape=(2, 1, 1)) - y2d = squeeze(x2d) - fn = xr_function([x2d], y2d) - x_test = xr_arange_like(x2d) - xr_assert_allclose(fn(x_test), x_test.squeeze()) - - # Redundant dims - x = xtensor("x", dims=("a", "b", "c"), shape=(2, 1, 1)) - y = squeeze(x, ["b", "b"]) - fn = xr_function([x], y) - x_test = xr_arange_like(x) - xr_assert_allclose(fn(x_test), x_test.squeeze("b")) - - # Order shouldn't matter - y1 = squeeze(x, ["b", "c"]) - y2 = squeeze(x, ["c", "b"]) - fn1 = xr_function([x], y1) - fn2 = xr_function([x], y2) - x_test = xr_arange_like(x) - xr_assert_allclose(fn1(x_test), fn2(x_test)) - - # Empty dims = no-op - y = squeeze(x, []) - fn = xr_function([x], y) - xr_assert_allclose(fn(x_test), x_test) - - # Squeeze all size-1 dims with dim=None - x = xtensor("x", dims=("a", "b"), shape=(1, 1)) - y = squeeze(x) - fn = xr_function([x], y) - x_test = xr_arange_like(x) - xr_assert_allclose(fn(x_test), x_test.squeeze()) - - # Symbolic dims: squeeze static size 1 at runtime - x_sym = xtensor("x_sym", dims=("a", "b", "c")) - y_sym = squeeze(x_sym, "b") - x_test = xr_arange_like(xtensor(dims=x_sym.dims, shape=(2, 1, 3))) - fn = xr_function([x_sym], y_sym) - xr_assert_allclose(fn(x_test), x_test.squeeze("b")) - - # Symbolic 1-size dim + known dim - x = xtensor("x", dims=("a", "b", "c"), shape=(None, 1, 3)) - y = squeeze(x, "b") - x_test = xr_arange_like(xtensor(dims=x.dims, shape=(4, 1, 3))) - fn = xr_function([x], y) - xr_assert_allclose(fn(x_test), x_test.squeeze("b")) - - # Squeeze then expand_dims → reversible - x = xtensor("x", dims=("batch", "time", "feature"), shape=(2, 1, 3)) - y = squeeze(x, "time") - z = expand_dims(y, "time") - fn = xr_function([x], z) - x_test = xr_arange_like(x) - xr_assert_allclose(fn(x_test).transpose(*x_test.dims), x_test) - - # No dims to squeeze → no-op - x = xtensor("x", dims=("row", "col", "batch"), shape=(2, 3, 4)) - y = squeeze(x) - fn = xr_function([x], y) - x_test = xr_arange_like(x) - xr_assert_allclose(fn(x_test), x_test) +def test_squeeze_explicit_dims(): + """Test squeeze with explicit dimension(s).""" + + # Single dimension + x1 = xtensor("x1", dims=("city", "country"), shape=(3, 1)) + y1 = squeeze(x1, "country") + fn1 = xr_function([x1], y1) + x1_test = xr_arange_like(x1) + xr_assert_allclose(fn1(x1_test), x1_test.squeeze("country")) + + # Multiple dimensions + x2 = xtensor("x2", dims=("a", "b", "c", "d"), shape=(2, 1, 1, 3)) + y2 = squeeze(x2, ["b", "c"]) + fn2 = xr_function([x2], y2) + x2_test = xr_arange_like(x2) + xr_assert_allclose(fn2(x2_test), x2_test.squeeze(["b", "c"])) + + # Order independence + x3 = xtensor("x3", dims=("a", "b", "c"), shape=(2, 1, 1)) + y3a = squeeze(x3, ["b", "c"]) + y3b = squeeze(x3, ["c", "b"]) + fn3a = xr_function([x3], y3a) + fn3b = xr_function([x3], y3b) + x3_test = xr_arange_like(x3) + xr_assert_allclose(fn3a(x3_test), fn3b(x3_test)) + + # Redundant dimensions + y3c = squeeze(x3, ["b", "b"]) + fn3c = xr_function([x3], y3c) + xr_assert_allclose(fn3c(x3_test), x3_test.squeeze("b")) + + # Empty list = no-op + y3d = squeeze(x3, []) + fn3d = xr_function([x3], y3d) + xr_assert_allclose(fn3d(x3_test), x3_test) + + +def test_squeeze_implicit_dims(): + """Test squeeze with implicit dim=None (all size-1 dimensions).""" + + # All dimensions size 1 + x1 = xtensor("x1", dims=("a", "b"), shape=(1, 1)) + y1 = squeeze(x1) + fn1 = xr_function([x1], y1) + x1_test = xr_arange_like(x1) + xr_assert_allclose(fn1(x1_test), x1_test.squeeze()) + + # No dimensions size 1 = no-op + x2 = xtensor("x2", dims=("row", "col", "batch"), shape=(2, 3, 4)) + y2 = squeeze(x2) + fn2 = xr_function([x2], y2) + x2_test = xr_arange_like(x2) + xr_assert_allclose(fn2(x2_test), x2_test) + + # Symbolic shape where runtime shape is 1 → should squeeze + x3 = xtensor("x3", dims=("a", "b", "c")) # shape unknown + y3 = squeeze(x3, "b") + x3_test = xr_arange_like(xtensor(dims=x3.dims, shape=(2, 1, 3))) + fn3 = xr_function([x3], y3) + xr_assert_allclose(fn3(x3_test), x3_test.squeeze("b")) + + # Mixed static + symbolic shapes, where symbolic shape is 1 + x4 = xtensor("x4", dims=("a", "b", "c"), shape=(None, 1, 3)) + y4 = squeeze(x4, "b") + x4_test = xr_arange_like(xtensor(dims=x4.dims, shape=(4, 1, 3))) + fn4 = xr_function([x4], y4) + xr_assert_allclose(fn4(x4_test), x4_test.squeeze("b")) + + # Reversibility with expand_dims + x5 = xtensor("x5", dims=("batch", "time", "feature"), shape=(2, 1, 3)) + y5 = squeeze(x5, "time") + z5 = expand_dims(y5, "time") + fn5 = xr_function([x5], z5) + x5_test = xr_arange_like(x5) + xr_assert_allclose(fn5(x5_test).transpose(*x5_test.dims), x5_test) def test_squeeze_errors(): - # Squeeze nonexistent dim - x = xtensor("x", dims=("city", "country"), shape=(3, 1)) + """Test error cases for squeeze.""" + + # Non-existent dimension + x1 = xtensor("x1", dims=("city", "country"), shape=(3, 1)) with pytest.raises(ValueError, match="Dimension .* not found"): - squeeze(x, "time") + squeeze(x1, "time") - # Squeeze non-size-1 dim + # Dimension size > 1 with pytest.raises(ValueError, match="has static size .* not 1"): - squeeze(x, "city") - - # Symbolic dim is not size 1 at runtime - x = xtensor("x", dims=("a", "b", "c")) - y = squeeze(x, "b") - x_test = xr_arange_like(xtensor(dims=x.dims, shape=(2, 2, 3))) - fn = xr_function([x], y) - with pytest.raises(Exception): # shape assertion fails at runtime - fn(x_test) + squeeze(x1, "city") + + # Symbolic shape: dim is not 1 at runtime → should raise + x2 = xtensor("x2", dims=("a", "b", "c")) # shape unknown + y2 = squeeze(x2, "b") + x2_test = xr_arange_like(xtensor(dims=x2.dims, shape=(2, 2, 3))) + fn2 = xr_function([x2], y2) + with pytest.raises(Exception): + fn2(x2_test) From 2bb1fce99daac8084120966977b68ebea5370c3d Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Wed, 4 Jun 2025 15:12:57 -0400 Subject: [PATCH 10/21] Working on expand_dims --- pytensor/xtensor/rewriting/shape.py | 6 ++ pytensor/xtensor/shape.py | 61 +++++------ tests/xtensor/test_shape.py | 161 ++++++++++++++++------------ 3 files changed, 128 insertions(+), 100 deletions(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index bc374046cd..788a303b7c 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -1,9 +1,11 @@ import numpy as np from pytensor.graph import node_rewriter +from pytensor.raise_op import Assert from pytensor.tensor import ( broadcast_to, expand_dims, + gt, join, moveaxis, specify_shape, @@ -147,6 +149,10 @@ def local_expand_dims_reshape(fgraph, node): x_tensor_expanded = broadcast_to(x_tensor_expanded, target_shape) else: # Symbolic size: enforce shape so broadcast happens downstream correctly + # Also validate that size is positive + x_tensor_expanded = Assert(msg="size must be positive")( + x_tensor_expanded, gt(size, 0) + ) x_tensor_expanded = specify_shape(x_tensor_expanded, target_shape) new_out = xtensor_from_tensor(x_tensor_expanded, dims=node.outputs[0].type.dims) diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 70ca5e82c6..733467ae7c 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -306,23 +306,11 @@ def concat(xtensors, dim: str): class ExpandDims(XOp): - """Add a new dimension to an XTensorVariable. - - Parameters - ---------- - dim : str or None - The name of the new dimension. If None, no new dimension is added. - size : int or symbolic, optional - The size of the new dimension (default 1). - """ + """Add a new dimension to an XTensorVariable.""" __props__ = ("dims", "size") def __init__(self, dim, size=1): - if dim is not None and not isinstance(dim, str): - raise TypeError(f"`dim` must be a string or None, got: {type(dim)}") - if isinstance(size, int | np.integer) and size <= 0: - raise ValueError(f"size must be positive, got: {size}") self.dims = dim self.size = size @@ -330,23 +318,17 @@ def make_node(self, x): x = as_xtensor(x) if self.dims is None: + # No-op: return same variable return Apply(self, [x], [x]) - if self.dims in x.type.dims: - raise ValueError(f"Dimension {self.dims} already exists in {x.type.dims}") + # Insert new dim at front + new_dims = (self.dims, *x.type.dims) - # Handle scalar case - if not x.type.dims: - new_dims = (self.dims,) - new_shape = (self.size,) + # Determine shape + if isinstance(self.size, int | np.integer): + new_shape = (self.size, *x.type.shape) else: - # Use symbolic shape - new_dims = (self.dims, *x.type.dims) - if isinstance(self.size, int | np.integer): - new_shape = (self.size, *x.type.shape) - else: - # For symbolic size, we need to use a symbolic shape - new_shape = (None, *x.type.shape) + new_shape = (None, *x.type.shape) # symbolic size out = xtensor( dtype=x.type.dtype, @@ -368,17 +350,36 @@ def expand_dims(x, dim: str | None, size=1): Parameters ---------- x : XTensorVariable - The input tensor + Input tensor dim : str or None - The name of the new dimension + Name of new dimension. If None, returns x unchanged. size : int or symbolic, optional - The size of the new dimension (default 1) + Size of the new dimension (default 1) Returns ------- XTensorVariable - A new tensor with the expanded dimension + Tensor with the new dimension inserted """ + x = as_xtensor(x) + + if dim is None: + return x # No-op + + if not isinstance(dim, str): + raise TypeError(f"`dim` must be a string or None, got: {type(dim)}") + + if dim in x.type.dims: + raise ValueError(f"Dimension {dim} already exists in {x.type.dims}") + + if isinstance(size, int | np.integer): + if size <= 0: + raise ValueError(f"size must be positive, got: {size}") + elif not ( + hasattr(size, "ndim") and getattr(size, "ndim", None) == 0 # symbolic scalar + ): + raise TypeError(f"size must be an int or scalar variable, got: {type(size)}") + return ExpandDims(dim=dim, size=size)(x) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 28cc6f28c4..1be01ef5ba 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -265,133 +265,154 @@ def test_concat_scalar(): xr_assert_allclose(res, expected_res) -def test_expand_dims(): +def test_expand_dims_explicit(): + """Test expand_dims with explicitly named dimensions and sizes.""" + # 1D case - x_xr = xr.DataArray([0, 1, 2], dims=["city"]) - y_xr = x_xr.expand_dims("country") x = xtensor("x", dims=("city",), shape=(3,)) y = expand_dims(x, "country") fn = xr_function([x], y) - xr_assert_allclose(fn(x_xr), y_xr) + x_xr = xr_arange_like(x) + xr_assert_allclose(fn(x_xr), x_xr.expand_dims("country")) # 2D case - x_xr = xr.DataArray([[0, 1], [2, 3]], dims=["city", "year"]) - y_xr = x_xr.expand_dims("country") x = xtensor("x", dims=("city", "year"), shape=(2, 2)) y = expand_dims(x, "country") fn = xr_function([x], y) - xr_assert_allclose(fn(x_xr), y_xr) + xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("country")) # 3D case - x_xr = xr.DataArray( - [[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dims=["city", "year", "month"] - ) - y_xr = x_xr.expand_dims("country") x = xtensor("x", dims=("city", "year", "month"), shape=(2, 2, 2)) y = expand_dims(x, "country") fn = xr_function([x], y) - xr_assert_allclose(fn(x_xr), y_xr) + xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("country")) - # Test that expand_dims is reversible with squeeze - x_xr = xr.DataArray([0, 1, 2], dims=["city"]) - y_xr = x_xr.expand_dims("country") - z_xr = y_xr.squeeze("country") - x = xtensor("x", dims=("city",), shape=(3,)) - y = expand_dims(x, "country") - z = squeeze(y, "country") - fn = xr_function([x], z) - xr_assert_allclose(fn(x_xr), z_xr) - - # Expanding with an existing dimension raises an error - with pytest.raises(ValueError, match="already exists"): - expand_dims(y, "city") - - # Expanding with None dimension should return the same variable (no-op) - z = expand_dims(x, None) - assert z is x - - # Test prepending different dimension names + # Prepending various dims x = xtensor("x", dims=("a", "b"), shape=(2, 3)) for new_dim in ("x", "y", "z"): y = expand_dims(x, new_dim) - expected_dims = (new_dim, *x.type.dims) - expected_shape = (1, *x.type.shape) - assert y.type.dims == expected_dims - assert y.type.shape == expected_shape + assert y.type.dims == (new_dim, "a", "b") + assert y.type.shape == (1, 2, 3) - # Explicit size=1 behaves like implicit broadcast - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - x_test = xr_arange_like(x) + # Explicit size=1 behaves like default y1 = expand_dims(x, "batch", size=1) y2 = expand_dims(x, "batch") fn1 = xr_function([x], y1) fn2 = xr_function([x], y2) + x_test = xr_arange_like(x) xr_assert_allclose(fn1(x_test), fn2(x_test)) - # Expanding with size=0 raises - with pytest.raises(ValueError, match="size must be.*positive"): - expand_dims(x, "batch", size=0) - - # Expanding a scalar + # Scalar expansion x = xtensor("x", dims=(), shape=()) y = expand_dims(x, "batch") assert y.type.dims == ("batch",) assert y.type.shape == (1,) fn = xr_function([x], y) - x_test = xr_arange_like(x) - expected = x_test.expand_dims("batch") - xr_assert_allclose(fn(x_test), expected) + xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("batch")) - # Expanding with a specified static size > 1 + # Static size > 1: broadcast x = xtensor("x", dims=("a", "b"), shape=(2, 3)) y = expand_dims(x, "batch", size=4) - assert y.type.dims == ("batch", "a", "b") - assert y.type.shape == (4, 2, 3) fn = xr_function([x], y) - x_test = xr_arange_like(x) expected = xr.DataArray( - np.broadcast_to(x_test.data, (4, 2, 3)), + np.broadcast_to(xr_arange_like(x).data, (4, 2, 3)), dims=("batch", "a", "b"), - coords={"a": x_test.coords["a"], "b": x_test.coords["b"]}, + coords={"a": xr_arange_like(x).coords["a"], "b": xr_arange_like(x).coords["b"]}, ) + xr_assert_allclose(fn(xr_arange_like(x)), expected) + + # Insert new dim between existing dims + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = expand_dims(x, "new") + # Insert new dim between a and b: ("a", "new", "b") + y = transpose(y, "a", "new", "b") + fn = xr_function([x], y) + x_test = xr_arange_like(x) + expected = x_test.expand_dims("new").transpose("a", "new", "b") xr_assert_allclose(fn(x_test), expected) - # Expanding with symbolic size = 1 (no broadcast) + # Expand with multiple dims + x = xtensor("x", dims=(), shape=()) + y = expand_dims(expand_dims(x, "a"), "b") + fn = xr_function([x], y) + expected = xr_arange_like(x).expand_dims("a").expand_dims("b") + xr_assert_allclose(fn(xr_arange_like(x)), expected) + + +def test_expand_dims_implicit(): + """Test expand_dims with default or symbolic sizes and dim=None.""" + + # Symbolic size=1: same as default size_sym_1 = scalar("size_sym_1", dtype="int64") x = xtensor("x", dims=("a", "b"), shape=(2, 3)) y = expand_dims(x, "batch", size=size_sym_1) fn = xr_function([x, size_sym_1], y, on_unused_input="ignore") - x_test = xr_arange_like(x) - expected = x_test.expand_dims("batch") - xr_assert_allclose(fn(x_test, 1), expected) + expected = xr_arange_like(x).expand_dims("batch") + xr_assert_allclose(fn(xr_arange_like(x), 1), expected) - # Expanding with symbolic size > 1 (but no broadcast expected) + # Symbolic size > 1 (but expand only adds dim=1) size_sym_4 = scalar("size_sym_4", dtype="int64") - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) y = expand_dims(x, "batch", size=size_sym_4) fn = xr_function([x, size_sym_4], y, on_unused_input="ignore") - x_test = xr_arange_like(x) - - # Even if symbolic size is 4, expand_dims will only insert dim=1 - expected = x_test.expand_dims("batch") - xr_assert_allclose(fn(x_test, 4), expected) + xr_assert_allclose(fn(xr_arange_like(x), 4), expected) - # Reversibility: expand_dims then squeeze + # Reversibility: expand then squeeze x = xtensor("x", dims=("a", "b"), shape=(2, 3)) y = expand_dims(x, "batch") z = squeeze(y, "batch") fn = xr_function([x], z) - x_test = xr_arange_like(x) - xr_assert_allclose(fn(x_test), x_test) + xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x)) - # Expand with dim=None is a no-op + # expand_dims with dim=None = no-op x = xtensor("x", dims=("a",), shape=(3,)) y = expand_dims(x, None) - assert y.type.dims == x.type.dims - assert y.type.shape == x.type.shape fn = xr_function([x], y) + xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x)) + + # broadcast after symbolic size + size_sym = scalar("size_sym", dtype="int64") + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = expand_dims(x, "batch", size=size_sym) + z = y + y # triggers shape alignment + fn = xr_function([x, size_sym], z, on_unused_input="ignore") x_test = xr_arange_like(x) - xr_assert_allclose(fn(x_test), x_test) + out = fn(x_test, 1) + expected = x_test.expand_dims("batch") + x_test.expand_dims("batch") + xr_assert_allclose(out, expected) + + +def test_expand_dims_errors(): + """Test error handling in expand_dims.""" + + # Expanding existing dim + x = xtensor("x", dims=("city",), shape=(3,)) + y = expand_dims(x, "country") + with pytest.raises(ValueError, match="already exists"): + expand_dims(y, "city") + + # Size = 0 is invalid + with pytest.raises(ValueError, match="size must be.*positive"): + expand_dims(x, "batch", size=0) + + # Invalid dim type + with pytest.raises(TypeError): + expand_dims(x, 123) + + # Invalid size type + with pytest.raises(TypeError): + expand_dims(x, "new", size=[1]) + + # Duplicate dimension creation + y = expand_dims(x, "new") + with pytest.raises(ValueError): + expand_dims(y, "new") + + # Symbolic size with invalid runtime value + size_sym = scalar("size_sym", dtype="int64") + y = expand_dims(x, "batch", size=size_sym) + fn = xr_function([x, size_sym], y, on_unused_input="ignore") + with pytest.raises(Exception): + fn(xr_arange_like(x), 0) def test_squeeze_explicit_dims(): From 7a308b90e9a773d75b33340589949d4331b97d18 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Wed, 4 Jun 2025 15:17:27 -0400 Subject: [PATCH 11/21] lint --- tests/xtensor/test_shape.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 6e38f870f5..6231eb3e28 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -1,6 +1,7 @@ # ruff: noqa: E402 import pytest + pytest.importorskip("xarray") import re From 9c1a0b7c2011963eb2e14e0687ad515622944752 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Thu, 5 Jun 2025 11:01:08 -0400 Subject: [PATCH 12/21] Update pytensor/xtensor/shape.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pytensor/xtensor/shape.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 733467ae7c..d5da400947 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -397,8 +397,8 @@ class Squeeze(XOp): __props__ = ("dims",) - def __init__(self, dim): - self.dims = dim + def __init__(self, dims): + self.dims = dims def make_node(self, x): x = as_xtensor(x) From 102479855c0b552ef2f2fd984f5d32dcdecc4f06 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Thu, 5 Jun 2025 11:25:37 -0400 Subject: [PATCH 13/21] Cleaning up squeeze --- pytensor/xtensor/rewriting/shape.py | 22 ++++++---------------- pytensor/xtensor/shape.py | 15 +++++++-------- tests/xtensor/test_shape.py | 18 +++++++++++++++++- 3 files changed, 30 insertions(+), 25 deletions(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index 788a303b7c..fc0f7b5a73 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -162,28 +162,18 @@ def local_expand_dims_reshape(fgraph, node): @register_xcanonicalize @node_rewriter([Squeeze]) def local_squeeze_reshape(fgraph, node): - """Rewrite rule to convert Squeeze to pytensor.tensor.squeeze.""" - if not isinstance(node.op, Squeeze): - return False - - [x] = node.inputs - in_dims = x.type.dims + """Rewrite Squeeze to tensor.squeeze.""" + x = node.inputs[0] dim = node.op.dims - # Determine which axes to squeeze if dim is None: - # Infer axes by comparing input and output dims - out_dims = node.outputs[0].type.dims - axes_to_squeeze = tuple(i for i, d in enumerate(in_dims) if d not in out_dims) - else: - dims_to_remove = [dim] if isinstance(dim, str) else dim - axes_to_squeeze = tuple(in_dims.index(d) for d in dims_to_remove) - - # Nothing to squeeze? Just return input unchanged - if not axes_to_squeeze: return [x] x_tensor = tensor_from_xtensor(x) + x_dims = x.type.dims + dims_to_remove = [dim] if isinstance(dim, str) else dim + axes_to_squeeze = tuple(x_dims.index(d) for d in dims_to_remove) x_tensor_squeezed = squeeze(x_tensor, axis=axes_to_squeeze) + new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims) return [new_out] diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index d5da400947..7385b726d8 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -397,19 +397,21 @@ class Squeeze(XOp): __props__ = ("dims",) - def __init__(self, dims): - self.dims = dims + def __init__(self, dim): + self.dims = tuple(sorted(set(dim))) def make_node(self, x): x = as_xtensor(x) # Validate that dims exist and are size-1 if statically known dims_to_remove = [] + x_dims = x.type.dims + x_shape = x.type.shape for d in self.dims: - if d not in x.type.dims: + if d not in x_dims: raise ValueError(f"Dimension {d} not found in {x.type.dims}") - idx = x.type.dims.index(d) - dim_size = x.type.shape[idx] + idx = x_dims.index(d) + dim_size = x_shape[idx] if dim_size is not None and dim_size != 1: raise ValueError(f"Dimension {d} has static size {dim_size}, not 1") dims_to_remove.append(idx) @@ -454,9 +456,6 @@ def squeeze(x, dim=None): else: dims = tuple(dim) - # Normalize: deduplicate and sort - dims = tuple(sorted(set(dims))) - if not dims: return x # no-op if nothing to squeeze diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 6231eb3e28..582ba09dc5 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -448,7 +448,7 @@ def test_squeeze_explicit_dims(): # Redundant dimensions y3c = squeeze(x3, ["b", "b"]) fn3c = xr_function([x3], y3c) - xr_assert_allclose(fn3c(x3_test), x3_test.squeeze("b")) + xr_assert_allclose(fn3c(x3_test), x3_test.squeeze(["b", "b"])) # Empty list = no-op y3d = squeeze(x3, []) @@ -495,6 +495,22 @@ def test_squeeze_implicit_dims(): x5_test = xr_arange_like(x5) xr_assert_allclose(fn5(x5_test).transpose(*x5_test.dims), x5_test) + """ + This test documents that we intentionally don't squeeze dimensions with symbolic shapes + (static_shape=None) even when they are 1 at runtime, while xarray does squeeze them. + """ + # Create a tensor with a symbolic dimension that will be 1 at runtime + x = xtensor("x", dims=("a", "b", "c")) # shape unknown + y = squeeze(x) # implicit dim=None should not squeeze symbolic dimensions + x_test = xr_arange_like(xtensor(dims=x.dims, shape=(2, 1, 3))) + fn = xr_function([x], y) + res = fn(x_test) + + # Our implementation should not squeeze the symbolic dimension + assert "b" in res.dims + # While xarray would squeeze it + assert "b" not in x_test.squeeze().dims + def test_squeeze_errors(): """Test error cases for squeeze.""" From 260b9b6f51b21253fb49402c76e9e6ef5185db9a Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Thu, 5 Jun 2025 11:56:43 -0400 Subject: [PATCH 14/21] Removing expand_dims --- pytensor/xtensor/rewriting/shape.py | 40 ------- pytensor/xtensor/shape.py | 80 ------------- tests/xtensor/test_shape.py | 167 ++-------------------------- 3 files changed, 7 insertions(+), 280 deletions(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index fc0f7b5a73..16c0952198 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -1,11 +1,6 @@ -import numpy as np - from pytensor.graph import node_rewriter -from pytensor.raise_op import Assert from pytensor.tensor import ( broadcast_to, - expand_dims, - gt, join, moveaxis, specify_shape, @@ -15,7 +10,6 @@ from pytensor.xtensor.rewriting.basic import register_xcanonicalize from pytensor.xtensor.shape import ( Concat, - ExpandDims, Squeeze, Stack, Transpose, @@ -125,40 +119,6 @@ def lower_transpose(fgraph, node): return [new_out] -@register_xcanonicalize -@node_rewriter([ExpandDims]) -def local_expand_dims_reshape(fgraph, node): - """Rewrite ExpandDims to tensor.expand_dims and optionally broadcast_to or specify_shape.""" - if not isinstance(node.op, ExpandDims): - return False - - x = node.inputs[0] - dim = node.op.dims - size = getattr(node.op, "size", 1) - - if dim is None: - return [x] - - x_tensor = tensor_from_xtensor(x) - x_tensor_expanded = expand_dims(x_tensor, axis=0) - - target_shape = node.outputs[0].type.shape - - if isinstance(size, int | np.integer): - if size != 1 and None not in target_shape: - x_tensor_expanded = broadcast_to(x_tensor_expanded, target_shape) - else: - # Symbolic size: enforce shape so broadcast happens downstream correctly - # Also validate that size is positive - x_tensor_expanded = Assert(msg="size must be positive")( - x_tensor_expanded, gt(size, 0) - ) - x_tensor_expanded = specify_shape(x_tensor_expanded, target_shape) - - new_out = xtensor_from_tensor(x_tensor_expanded, dims=node.outputs[0].type.dims) - return [new_out] - - @register_xcanonicalize @node_rewriter([Squeeze]) def local_squeeze_reshape(fgraph, node): diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 7385b726d8..9b43face06 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -2,8 +2,6 @@ from collections.abc import Sequence from typing import Literal -import numpy as np - from pytensor import Variable from pytensor.graph import Apply from pytensor.scalar import discrete_dtypes, upcast @@ -305,84 +303,6 @@ def concat(xtensors, dim: str): return Concat(dim=dim)(*xtensors) -class ExpandDims(XOp): - """Add a new dimension to an XTensorVariable.""" - - __props__ = ("dims", "size") - - def __init__(self, dim, size=1): - self.dims = dim - self.size = size - - def make_node(self, x): - x = as_xtensor(x) - - if self.dims is None: - # No-op: return same variable - return Apply(self, [x], [x]) - - # Insert new dim at front - new_dims = (self.dims, *x.type.dims) - - # Determine shape - if isinstance(self.size, int | np.integer): - new_shape = (self.size, *x.type.shape) - else: - new_shape = (None, *x.type.shape) # symbolic size - - out = xtensor( - dtype=x.type.dtype, - shape=new_shape, - dims=new_dims, - ) - return Apply(self, [x], [out]) - - def infer_shape(self, fgraph, node, input_shapes): - (input_shape,) = input_shapes - if self.dims is None: - return [input_shape] - return [(self.size, *list(input_shape))] - - -def expand_dims(x, dim: str | None, size=1): - """Add a new dimension to an XTensorVariable. - - Parameters - ---------- - x : XTensorVariable - Input tensor - dim : str or None - Name of new dimension. If None, returns x unchanged. - size : int or symbolic, optional - Size of the new dimension (default 1) - - Returns - ------- - XTensorVariable - Tensor with the new dimension inserted - """ - x = as_xtensor(x) - - if dim is None: - return x # No-op - - if not isinstance(dim, str): - raise TypeError(f"`dim` must be a string or None, got: {type(dim)}") - - if dim in x.type.dims: - raise ValueError(f"Dimension {dim} already exists in {x.type.dims}") - - if isinstance(size, int | np.integer): - if size <= 0: - raise ValueError(f"size must be positive, got: {size}") - elif not ( - hasattr(size, "ndim") and getattr(size, "ndim", None) == 0 # symbolic scalar - ): - raise TypeError(f"size must be an int or scalar variable, got: {type(size)}") - - return ExpandDims(dim=dim, size=size)(x) - - class Squeeze(XOp): """Remove specified dimensions from an XTensorVariable. diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 582ba09dc5..383cae6342 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -9,14 +9,11 @@ import numpy as np import pytest -import xarray as xr from xarray import DataArray from xarray import concat as xr_concat -from pytensor.tensor import scalar from pytensor.xtensor.shape import ( concat, - expand_dims, squeeze, stack, transpose, @@ -269,156 +266,6 @@ def test_concat_scalar(): xr_assert_allclose(res, expected_res) -def test_expand_dims_explicit(): - """Test expand_dims with explicitly named dimensions and sizes.""" - - # 1D case - x = xtensor("x", dims=("city",), shape=(3,)) - y = expand_dims(x, "country") - fn = xr_function([x], y) - x_xr = xr_arange_like(x) - xr_assert_allclose(fn(x_xr), x_xr.expand_dims("country")) - - # 2D case - x = xtensor("x", dims=("city", "year"), shape=(2, 2)) - y = expand_dims(x, "country") - fn = xr_function([x], y) - xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("country")) - - # 3D case - x = xtensor("x", dims=("city", "year", "month"), shape=(2, 2, 2)) - y = expand_dims(x, "country") - fn = xr_function([x], y) - xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("country")) - - # Prepending various dims - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - for new_dim in ("x", "y", "z"): - y = expand_dims(x, new_dim) - assert y.type.dims == (new_dim, "a", "b") - assert y.type.shape == (1, 2, 3) - - # Explicit size=1 behaves like default - y1 = expand_dims(x, "batch", size=1) - y2 = expand_dims(x, "batch") - fn1 = xr_function([x], y1) - fn2 = xr_function([x], y2) - x_test = xr_arange_like(x) - xr_assert_allclose(fn1(x_test), fn2(x_test)) - - # Scalar expansion - x = xtensor("x", dims=(), shape=()) - y = expand_dims(x, "batch") - assert y.type.dims == ("batch",) - assert y.type.shape == (1,) - fn = xr_function([x], y) - xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("batch")) - - # Static size > 1: broadcast - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - y = expand_dims(x, "batch", size=4) - fn = xr_function([x], y) - expected = xr.DataArray( - np.broadcast_to(xr_arange_like(x).data, (4, 2, 3)), - dims=("batch", "a", "b"), - coords={"a": xr_arange_like(x).coords["a"], "b": xr_arange_like(x).coords["b"]}, - ) - xr_assert_allclose(fn(xr_arange_like(x)), expected) - - # Insert new dim between existing dims - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - y = expand_dims(x, "new") - # Insert new dim between a and b: ("a", "new", "b") - y = transpose(y, "a", "new", "b") - fn = xr_function([x], y) - x_test = xr_arange_like(x) - expected = x_test.expand_dims("new").transpose("a", "new", "b") - xr_assert_allclose(fn(x_test), expected) - - # Expand with multiple dims - x = xtensor("x", dims=(), shape=()) - y = expand_dims(expand_dims(x, "a"), "b") - fn = xr_function([x], y) - expected = xr_arange_like(x).expand_dims("a").expand_dims("b") - xr_assert_allclose(fn(xr_arange_like(x)), expected) - - -def test_expand_dims_implicit(): - """Test expand_dims with default or symbolic sizes and dim=None.""" - - # Symbolic size=1: same as default - size_sym_1 = scalar("size_sym_1", dtype="int64") - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - y = expand_dims(x, "batch", size=size_sym_1) - fn = xr_function([x, size_sym_1], y, on_unused_input="ignore") - expected = xr_arange_like(x).expand_dims("batch") - xr_assert_allclose(fn(xr_arange_like(x), 1), expected) - - # Symbolic size > 1 (but expand only adds dim=1) - size_sym_4 = scalar("size_sym_4", dtype="int64") - y = expand_dims(x, "batch", size=size_sym_4) - fn = xr_function([x, size_sym_4], y, on_unused_input="ignore") - xr_assert_allclose(fn(xr_arange_like(x), 4), expected) - - # Reversibility: expand then squeeze - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - y = expand_dims(x, "batch") - z = squeeze(y, "batch") - fn = xr_function([x], z) - xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x)) - - # expand_dims with dim=None = no-op - x = xtensor("x", dims=("a",), shape=(3,)) - y = expand_dims(x, None) - fn = xr_function([x], y) - xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x)) - - # broadcast after symbolic size - size_sym = scalar("size_sym", dtype="int64") - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - y = expand_dims(x, "batch", size=size_sym) - z = y + y # triggers shape alignment - fn = xr_function([x, size_sym], z, on_unused_input="ignore") - x_test = xr_arange_like(x) - out = fn(x_test, 1) - expected = x_test.expand_dims("batch") + x_test.expand_dims("batch") - xr_assert_allclose(out, expected) - - -def test_expand_dims_errors(): - """Test error handling in expand_dims.""" - - # Expanding existing dim - x = xtensor("x", dims=("city",), shape=(3,)) - y = expand_dims(x, "country") - with pytest.raises(ValueError, match="already exists"): - expand_dims(y, "city") - - # Size = 0 is invalid - with pytest.raises(ValueError, match="size must be.*positive"): - expand_dims(x, "batch", size=0) - - # Invalid dim type - with pytest.raises(TypeError): - expand_dims(x, 123) - - # Invalid size type - with pytest.raises(TypeError): - expand_dims(x, "new", size=[1]) - - # Duplicate dimension creation - y = expand_dims(x, "new") - with pytest.raises(ValueError): - expand_dims(y, "new") - - # Symbolic size with invalid runtime value - size_sym = scalar("size_sym", dtype="int64") - y = expand_dims(x, "batch", size=size_sym) - fn = xr_function([x, size_sym], y, on_unused_input="ignore") - with pytest.raises(Exception): - fn(xr_arange_like(x), 0) - - def test_squeeze_explicit_dims(): """Test squeeze with explicit dimension(s).""" @@ -487,13 +334,13 @@ def test_squeeze_implicit_dims(): fn4 = xr_function([x4], y4) xr_assert_allclose(fn4(x4_test), x4_test.squeeze("b")) - # Reversibility with expand_dims - x5 = xtensor("x5", dims=("batch", "time", "feature"), shape=(2, 1, 3)) - y5 = squeeze(x5, "time") - z5 = expand_dims(y5, "time") - fn5 = xr_function([x5], z5) - x5_test = xr_arange_like(x5) - xr_assert_allclose(fn5(x5_test).transpose(*x5_test.dims), x5_test) + # Reversibility with expand_dims (restore when expand_dims is implemented) + # x5 = xtensor("x5", dims=("batch", "time", "feature"), shape=(2, 1, 3)) + # y5 = squeeze(x5, "time") + # z5 = expand_dims(y5, "time") + # fn5 = xr_function([x5], z5) + # x5_test = xr_arange_like(x5) + # xr_assert_allclose(fn5(x5_test).transpose(*x5_test.dims), x5_test) """ This test documents that we intentionally don't squeeze dimensions with symbolic shapes From 915a3682ac50d17c6b13147b3991975b4d69069a Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Thu, 5 Jun 2025 11:58:30 -0400 Subject: [PATCH 15/21] Removing unneded check --- pytensor/xtensor/rewriting/shape.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index 16c0952198..d590eb8a12 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -126,9 +126,6 @@ def local_squeeze_reshape(fgraph, node): x = node.inputs[0] dim = node.op.dims - if dim is None: - return [x] - x_tensor = tensor_from_xtensor(x) x_dims = x.type.dims dims_to_remove = [dim] if isinstance(dim, str) else dim From 05dac9ec9dc9257b8e034515f5a01de8e2bf547b Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Thu, 5 Jun 2025 13:19:16 -0400 Subject: [PATCH 16/21] Update pytensor/xtensor/shape.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pytensor/xtensor/shape.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 9b43face06..d9ec15079b 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -360,7 +360,7 @@ def squeeze(x, dim=None): The input tensor dim : str or None or iterable of str, optional The name(s) of the dimension(s) to remove. If None, all dimensions of size 1 - (known statically) will be removed. Dimensions with symbolic shape will be retained. + (known statically) will be removed. Dimensions with unknown static shape will be retained, even if they have size 1 at runtime. Returns ------- From 98d297e3d1610a054ef063529001783f9afc7323 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Thu, 5 Jun 2025 13:20:18 -0400 Subject: [PATCH 17/21] Update pytensor/xtensor/shape.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pytensor/xtensor/shape.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index d9ec15079b..87b6be6e80 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -379,4 +379,4 @@ def squeeze(x, dim=None): if not dims: return x # no-op if nothing to squeeze - return Squeeze(dim=dims)(x) + return Squeeze(dims=dims)(x) From 3202c4c93d4849453a29a0dd32d748386d7ba1ac Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Thu, 5 Jun 2025 13:21:30 -0400 Subject: [PATCH 18/21] Update pytensor/xtensor/rewriting/shape.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pytensor/xtensor/rewriting/shape.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index d590eb8a12..55d7e388e4 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -124,11 +124,9 @@ def lower_transpose(fgraph, node): def local_squeeze_reshape(fgraph, node): """Rewrite Squeeze to tensor.squeeze.""" x = node.inputs[0] - dim = node.op.dims - x_tensor = tensor_from_xtensor(x) x_dims = x.type.dims - dims_to_remove = [dim] if isinstance(dim, str) else dim + dims_to_remove = node.op.dims axes_to_squeeze = tuple(x_dims.index(d) for d in dims_to_remove) x_tensor_squeezed = squeeze(x_tensor, axis=axes_to_squeeze) From 8d4fdd596887fc6cc9fe4f46d58e61e441171794 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Thu, 5 Jun 2025 13:21:59 -0400 Subject: [PATCH 19/21] Update pytensor/xtensor/shape.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pytensor/xtensor/shape.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 87b6be6e80..329e8a5163 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -317,8 +317,8 @@ class Squeeze(XOp): __props__ = ("dims",) - def __init__(self, dim): - self.dims = tuple(sorted(set(dim))) + def __init__(self, dims): + self.dims = tuple(sorted(set(dims))) def make_node(self, x): x = as_xtensor(x) From f000bbb53b5c54d07841d91c17d3153bf5b08fca Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Thu, 5 Jun 2025 13:31:26 -0400 Subject: [PATCH 20/21] All but one requested change --- tests/xtensor/test_shape.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 383cae6342..fcb7e03757 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -334,14 +334,6 @@ def test_squeeze_implicit_dims(): fn4 = xr_function([x4], y4) xr_assert_allclose(fn4(x4_test), x4_test.squeeze("b")) - # Reversibility with expand_dims (restore when expand_dims is implemented) - # x5 = xtensor("x5", dims=("batch", "time", "feature"), shape=(2, 1, 3)) - # y5 = squeeze(x5, "time") - # z5 = expand_dims(y5, "time") - # fn5 = xr_function([x5], z5) - # x5_test = xr_arange_like(x5) - # xr_assert_allclose(fn5(x5_test).transpose(*x5_test.dims), x5_test) - """ This test documents that we intentionally don't squeeze dimensions with symbolic shapes (static_shape=None) even when they are 1 at runtime, while xarray does squeeze them. From 2de95666c88577c0aa6646b78e4eca5ad5810b29 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Thu, 5 Jun 2025 16:04:59 -0400 Subject: [PATCH 21/21] Picking a nit --- pytensor/xtensor/rewriting/shape.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index 55d7e388e4..efe92fe367 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -123,7 +123,7 @@ def lower_transpose(fgraph, node): @node_rewriter([Squeeze]) def local_squeeze_reshape(fgraph, node): """Rewrite Squeeze to tensor.squeeze.""" - x = node.inputs[0] + [x] = node.inputs x_tensor = tensor_from_xtensor(x) x_dims = x.type.dims dims_to_remove = node.op.dims