Skip to content

Add squeeze for labeled tensors #1434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: labeled_tensors
Choose a base branch
from

Conversation

AllenDowney
Copy link

@AllenDowney AllenDowney commented May 30, 2025

Adding squeeze and expand_dims


📚 Documentation preview 📚: https://pytensor--1434.org.readthedocs.build/en/1434/

… ExpandDims op and rewrite rule to not add a new dimension when dim is None - Update tests to verify behavior matches xarray
@AllenDowney
Copy link
Author

@ricardoV94 Please take a look at squeeze. expand_dims is still WIP

XTensorVariable
A new tensor with the specified dimension removed
"""
return Squeeze(dim=dim)(x)
Copy link
Member

@ricardoV94 ricardoV94 May 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better not to have None in the Op. Do the conversion here and pass explicit dims to the Op. The reason for this has to do with PyTensor constraints.

Our Squeeze Op should always know which explicit dims are do be dropped, because the input could change subtly during rewrites, and now we find out a dimension has length 1 after all, which we didn't know before, and reapplying the same Op will change the output type, which is not allowed during rewrites.


Another note, xarray squeeze seems to accept axis argument to do positional squeeze, we should allow that and convert to dims: https://docs.xarray.dev/en/latest/generated/xarray.DataArray.squeeze.html#xarray-dataarray-squeeze

Better to always check the docs of the xarray method we're trying to emulate to be aware of special arguments

You may need to experiment a bit about what does xarray do if you specify both, or specify invalid dims/axis, to try and emulate the behavior on our side as much as is reasonable for us to do.

@AllenDowney AllenDowney changed the title Add expand dims squeeze Add squeeze Jun 2, 2025
@AllenDowney
Copy link
Author

@ricardoV94 I have restored the version with tests that validate against xarray behavior. I think squeeze is ready for review. Again, ignore expand_dims for now.

I have a question about the case where a dimension specifier is symbolic -- is the implementation here correct?

Comment on lines +128 to +129
if not isinstance(node.op, ExpandDims):
return False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check isn't needed, the node_rewriter argument is already used to preselect such nodes

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll not that for the next iteration, but expand_dims is not ready for review.

Comment on lines 135 to 137
# If dim is None, don't add a new dimension (matching xarray behavior)
if dim is None:
return [x]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to support this at the Op level, just make it return self when x.expand_dims(None) is called if we want to even support that

return [x]

# Create new dimensions list with the new dimension at the beginning
new_dims = [dim, *list(x.type.dims)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should support multiple expand_dims, not only one?


x = node.inputs[0]
dim = node.op.dim
size = getattr(node.op, "size", 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

size should be a symbolic input (or multiple if we have multiple dims) to the node, so you'll have x, *sizes = node.inputs. This way they can be arbitrary symbolic expressions and not just constants. Check how unstack does it.

Comment on lines +158 to +159
if not isinstance(node.op, Squeeze):
return False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not needed

if not isinstance(node.op, Squeeze):
return False

x = node.inputs[0]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick, I like to do [x] = node.inputs to be explicit that this is a single input node

Comment on lines 162 to 165
dim = node.op.dim

# Convert single dimension to iterable for consistent handling
dims_to_remove = [dim] if isinstance(dim, str) else dim
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sort of normalization should be done at the time the Op/node is defined. The earlier we normalize stuff the easier it is to work downstream.

Comment on lines 178 to 182
else:
# Find all dimensions of size 1
dim_indices = [i for i, s in enumerate(x.type.shape) if s == 1]
if not dim_indices:
return False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't happen at rewrite time. Decide at the time you create the Op/node what dimensions will be dropped and stick to those. This is a case where PyTensor deviates from numpy/xarray, due to it's non-eager nature and the ability to work with unknown shapes.

You can see this happening in pytensor like this:

import pytensor.tensor as pt
import numpy as np

x = pt.tensor("x", shape=(None, 2, 1, 2, None))
y = x.squeeze()
assert y.eval({x: np.zeros((1, 2, 1, 2, 1))}).shape  == (1, 2, 2, 1)

Only the dimension we knew to be length 1 when x.squeeze() was called was dropped. We never try to update which dimension we drop, because y is bound to it's type y.type, that cannot change during rewrites (well shape can go from None -> int), but ndim cannot change.

return False

# Create new dimensions list
new_dims = [d for i, d in enumerate(x.type.dims) if i not in dim_indices]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just reuse node.outputs[0].type.dims since you already did the work of figuring out the output dims in make_node



def squeeze(x, dim=None):
"""Remove dimensions of size 1 from an XTensorVariable.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a note that this deviates from numpy/xarray. Similar to what we have here:

2. Similarly, if `axis` is ``None``, only dimensions known to be broadcastable will be
removed, even if there are more dimensions that happen to be broadcastable when
the variable is evaluated.

The first point is actually not true anymore, so don't copy it

@ricardoV94
Copy link
Member

ricardoV94 commented Jun 2, 2025

I have a question about the case where a dimension specifier is symbolic -- is the implementation here correct?

I replied in the comment. It's not correct. Symbolic inputs have to show up in make_node not __init__. You can try to create such a graph like this (adapt it to a test format):

import pytensor
import pytensor.xtensor as px

x = px.xtensor("x", shape=(2,), dims=("a",))
b_size = px.xtensor("b_size", shape=(), dims=())
y = y.expand_dims(b=b_size)

y.eval({x: np.array([0, 1]), b_size: np.array(10)})

If you try this right now you may get an error or a silent bug, because b_size is not part of the symbolic graph of y as far as PyTensor can tell

@AllenDowney
Copy link
Author

@ricardoV94 squeeze is ready for another look.

expand_dims is still not ready for review

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did another pass on the Squeeze functionality.

I suggested removing the None case from the Op level and left some other minor comments about it.

I think the tests right now are a bit overkill / redundant / messy. I suggest grouping in different functions the following things:

  1. Tests with explicit squeeze dim (single, multiple, order independent)
  2. Tests with implicit None dim (including the case that at runtime deviates from xarray and as documented)
  3. Tests for errors raised by the Op at creation or runtime

to be size 1 at runtime.
"""

__props__ = ("dim",)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: call the Op prop dims instead of dim (still use dim in the user facing functions)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, but it means that dim and dims are all over the place now. Worth it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean they are all over the place now? Why is that?

@AllenDowney
Copy link
Author

@ricardoV94 squeeze is ready for another look, and expand_dims is ready, too

@twiecki twiecki changed the title Add squeeze Add squeeze for labeled tensors Jun 4, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants