Add squeeze for labeled tensors #1434
base: labeled_tensors
Conversation
… ExpandDims op and rewrite rule to not add a new dimension when dim is None - Update tests to verify behavior matches xarray
…and streamlining logic.
@ricardoV94 Please take a look at squeeze. expand_dims is still WIP
pytensor/xtensor/shape.py
Outdated
XTensorVariable
    A new tensor with the specified dimension removed
"""
return Squeeze(dim=dim)(x)
Better not to have `None` in the Op. Do the conversion here and pass explicit dims to the Op. The reason for this has to do with PyTensor constraints: our Squeeze Op should always know which explicit dims are to be dropped, because the input could change subtly during rewrites. If we then find out a dimension has length 1 after all, which we didn't know before, reapplying the same Op would change the output type, which is not allowed during rewrites.
Another note: xarray squeeze also accepts an `axis` argument to do positional squeeze; we should allow that and convert it to dims: https://docs.xarray.dev/en/latest/generated/xarray.DataArray.squeeze.html#xarray-dataarray-squeeze
Better to always check the docs of the xarray method we're trying to emulate, to be aware of special arguments.
You may need to experiment a bit with what xarray does if you specify both, or specify invalid dims/axis, to try to emulate the behavior on our side as much as is reasonable for us to do.
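For reference, a quick sketch of the xarray behavior in question (the array here is made up for illustration):
import numpy as np
import xarray as xr

da = xr.DataArray(np.zeros((2, 1, 3)), dims=("a", "b", "c"))

print(da.squeeze("b").dims)     # ('a', 'c') -- squeeze by dimension name
print(da.squeeze(axis=1).dims)  # ('a', 'c') -- positional squeeze via axis
# da.squeeze("a") raises an error, because "a" does not have size 1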
@ricardoV94 I have restored the version with tests that validate against xarray behavior. I think squeeze is ready for review. Again, ignore expand_dims for now. I have a question about the case where a dimension specifier is symbolic -- is the implementation here correct?
if not isinstance(node.op, ExpandDims):
    return False
This check isn't needed; the `node_rewriter` argument is already used to preselect such nodes.
I'll note that for the next iteration, but expand_dims is not ready for review.
pytensor/xtensor/rewriting/shape.py
Outdated
# If dim is None, don't add a new dimension (matching xarray behavior)
if dim is None:
    return [x]
We don't need to support this at the Op level; just make it return self when `x.expand_dims(None)` is called, if we want to even support that.
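A minimal sketch of that suggestion at the user-facing level, reusing the PR's `ExpandDims` Op (not the actual implementation):
def expand_dims(x, dim=None):
    if dim is None:
        # No-op at the function level, so the Op never sees None
        return x
    return ExpandDims(dim=dim)(x)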
pytensor/xtensor/rewriting/shape.py
Outdated
return [x]

# Create new dimensions list with the new dimension at the beginning
new_dims = [dim, *list(x.type.dims)]
We should support multiple expand_dims, not only one?
x = node.inputs[0]
dim = node.op.dim
size = getattr(node.op, "size", 1)
size should be a symbolic input (or multiple, if we have multiple dims) to the node, so you'll have `x, *sizes = node.inputs`. This way they can be arbitrary symbolic expressions and not just constants. Check how `unstack` does it.
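A hedged sketch of that pattern (the base Op class and output construction here are assumptions, not the PR's actual code):
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor import as_tensor_variable

class ExpandDims(Op):
    __props__ = ("dims",)

    def make_node(self, x, *sizes):
        # Sizes enter the graph as symbolic inputs of the node, so they
        # can be arbitrary expressions rather than constants stored on the Op
        sizes = [as_tensor_variable(s) for s in sizes]
        out = ...  # build the output XTensorType from x.type and self.dims
        return Apply(self, [x, *sizes], [out])

# The corresponding rewrite then unpacks them the same way:
# x, *sizes = node.inputs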
if not isinstance(node.op, Squeeze):
    return False
Not needed
pytensor/xtensor/rewriting/shape.py
Outdated
if not isinstance(node.op, Squeeze):
    return False

x = node.inputs[0]
Nitpick: I like to do `[x] = node.inputs` to be explicit that this is a single-input node.
pytensor/xtensor/rewriting/shape.py
Outdated
dim = node.op.dim

# Convert single dimension to iterable for consistent handling
dims_to_remove = [dim] if isinstance(dim, str) else dim
This sort of normalization should be done at the time the Op/node is defined. The earlier we normalize stuff, the easier it is to work downstream.
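For instance, a minimal sketch of normalizing in the user-facing function, so the Op always stores an explicit tuple of dims (an assumption about the eventual shape of the code, not the PR itself):
def squeeze(x, dim=None):
    if dim is None:
        # Resolve now which dims are statically known to be length 1
        dim = [d for d, s in zip(x.type.dims, x.type.shape) if s == 1]
    elif isinstance(dim, str):
        dim = [dim]
    return Squeeze(dims=tuple(dim))(x)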
pytensor/xtensor/rewriting/shape.py
Outdated
else:
    # Find all dimensions of size 1
    dim_indices = [i for i, s in enumerate(x.type.shape) if s == 1]
    if not dim_indices:
        return False
This shouldn't happen at rewrite time. Decide at the time you create the Op/node what dimensions will be dropped and stick to those. This is a case where PyTensor deviates from numpy/xarray, due to its non-eager nature and the ability to work with unknown shapes.
You can see this happening in pytensor like this:
import pytensor.tensor as pt
import numpy as np
x = pt.tensor("x", shape=(None, 2, 1, 2, None))
y = x.squeeze()
assert y.eval({x: np.zeros((1, 2, 1, 2, 1))}).shape == (1, 2, 2, 1)
Only the dimension we knew to be length 1 when `x.squeeze()` was called was dropped. We never try to update which dimensions we drop, because `y` is bound to its type `y.type`, which cannot change during rewrites (well, a shape entry can go from None to an int, but ndim cannot change).
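Continuing the same example, the static type of `y` already encodes that decision:
print(y.type.shape)  # (None, 2, 2, None): only the dimension statically
                     # known to have size 1 (index 2) was removed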
pytensor/xtensor/rewriting/shape.py
Outdated
return False

# Create new dimensions list
new_dims = [d for i, d in enumerate(x.type.dims) if i not in dim_indices]
Just reuse `node.outputs[0].type.dims`, since you already did the work of figuring out the output dims in make_node.
def squeeze(x, dim=None):
    """Remove dimensions of size 1 from an XTensorVariable.
Add a note that this deviates from numpy/xarray. Similar to what we have here:
pytensor/pytensor/tensor/extra_ops.py, lines 604 to 606 in 4c8c8b6:
2. Similarly, if `axis` is ``None``, only dimensions known to be broadcastable will be
   removed, even if there are more dimensions that happen to be broadcastable when
   the variable is evaluated.
The first point is actually not true anymore, so don't copy it.
I replied in the comment. It's not correct. Symbolic inputs have to show up in the graph, as inputs of the node:
import numpy as np
import pytensor
import pytensor.xtensor as px

x = px.xtensor("x", shape=(2,), dims=("a",))
b_size = px.xtensor("b_size", shape=(), dims=())
y = x.expand_dims(b=b_size)
y.eval({x: np.array([0, 1]), b_size: np.array(10)})
If you try this right now you may get an error or a silent bug, because the symbolic size never becomes an input of the node.
@ricardoV94 squeeze is ready for another look. expand_dims is still not ready for review
I did another pass on the Squeeze functionality.
I suggested removing the None case from the Op level and left some other minor comments about it.
I think the tests right now are a bit overkill / redundant / messy. I suggest grouping the following things in different functions (a sketch follows the list):
- Tests with explicit squeeze dim (single, multiple, order independent)
- Tests with implicit None dim (including the case that at runtime deviates from xarray and as documented)
- Tests for errors raised by the Op at creation or runtime
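A hypothetical sketch of that grouping (test names and bodies are placeholders, not the PR's actual tests):
def test_squeeze_explicit_dims():
    # Single dim, multiple dims, order independence
    ...

def test_squeeze_implicit_none_dim():
    # dim=None: only statically known length-1 dims are dropped, which
    # can deviate from xarray at runtime, as documented
    ...

def test_squeeze_errors():
    # Errors raised by the Op at creation time and at runtime
    ...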
pytensor/xtensor/shape.py
Outdated
to be size 1 at runtime.
"""

__props__ = ("dim",)
Nit: call the Op prop `dims` instead of `dim` (still use `dim` in the user-facing functions).
Done, but it means that `dim` and `dims` are all over the place now. Worth it?
What do you mean they are all over the place now? Why is that?
@ricardoV94 squeeze is ready for another look, and expand_dims is ready, too
Adding squeeze and expand_dims
📚 Documentation preview 📚: https://pytensor--1434.org.readthedocs.build/en/1434/