-
Notifications
You must be signed in to change notification settings - Fork 135
Add squeeze for labeled tensors #1434
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b3e859c
d824870
a076966
e2ffe1c
7489489
22adb6f
a7e2bf8
332139d
4b2f0f7
2120b1a
2bb1fce
dd13fc7
7a308b9
9c1a0b7
1024798
260b9b6
915a368
05dac9e
98d297e
3202c4c
8d4fdd5
f000bbb
2de9566
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,10 +8,17 @@ | |
from itertools import chain, combinations | ||
|
||
import numpy as np | ||
import pytest | ||
from xarray import DataArray | ||
from xarray import concat as xr_concat | ||
|
||
from pytensor.xtensor.shape import concat, stack, transpose, unstack | ||
from pytensor.xtensor.shape import ( | ||
concat, | ||
squeeze, | ||
stack, | ||
transpose, | ||
unstack, | ||
) | ||
from pytensor.xtensor.type import xtensor | ||
from tests.xtensor.util import ( | ||
xr_arange_like, | ||
|
@@ -21,6 +28,9 @@ | |
) | ||
|
||
|
||
pytest.importorskip("xarray") | ||
AllenDowney marked this conversation as resolved.
Show resolved
Hide resolved
AllenDowney marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
def powerset(iterable, min_group_size=0): | ||
"Subsequences of the iterable from shortest to longest." | ||
# powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3) | ||
|
@@ -254,3 +264,109 @@ def test_concat_scalar(): | |
res = fn(x1_test, x2_test) | ||
expected_res = xr_concat([x1_test, x2_test], dim="new_dim") | ||
xr_assert_allclose(res, expected_res) | ||
|
||
|
||
def test_squeeze_explicit_dims(): | ||
"""Test squeeze with explicit dimension(s).""" | ||
|
||
# Single dimension | ||
x1 = xtensor("x1", dims=("city", "country"), shape=(3, 1)) | ||
y1 = squeeze(x1, "country") | ||
fn1 = xr_function([x1], y1) | ||
x1_test = xr_arange_like(x1) | ||
xr_assert_allclose(fn1(x1_test), x1_test.squeeze("country")) | ||
|
||
# Multiple dimensions | ||
x2 = xtensor("x2", dims=("a", "b", "c", "d"), shape=(2, 1, 1, 3)) | ||
y2 = squeeze(x2, ["b", "c"]) | ||
fn2 = xr_function([x2], y2) | ||
x2_test = xr_arange_like(x2) | ||
xr_assert_allclose(fn2(x2_test), x2_test.squeeze(["b", "c"])) | ||
|
||
# Order independence | ||
x3 = xtensor("x3", dims=("a", "b", "c"), shape=(2, 1, 1)) | ||
y3a = squeeze(x3, ["b", "c"]) | ||
y3b = squeeze(x3, ["c", "b"]) | ||
fn3a = xr_function([x3], y3a) | ||
fn3b = xr_function([x3], y3b) | ||
x3_test = xr_arange_like(x3) | ||
xr_assert_allclose(fn3a(x3_test), fn3b(x3_test)) | ||
Comment on lines
+286
to
+293
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Combine this with the previous test. Test both of them against xarray. You don't need one function per case, the function can have two outputs, which should be a faster test, as it only trigger the compilation machinery once. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have too many questions about this comment. If you want to make this change after merging, that might be more efficient than explaining. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Your previous check was already testing a squeeze of multiple dimensions, so you can combine this which also checks multiple dimensions + the fact that order doesn't matter. This test is a superset of the previous one. Then the point about combining multiple outputs is to do |
||
|
||
# Redundant dimensions | ||
y3c = squeeze(x3, ["b", "b"]) | ||
fn3c = xr_function([x3], y3c) | ||
xr_assert_allclose(fn3c(x3_test), x3_test.squeeze(["b", "b"])) | ||
|
||
# Empty list = no-op | ||
y3d = squeeze(x3, []) | ||
fn3d = xr_function([x3], y3d) | ||
xr_assert_allclose(fn3d(x3_test), x3_test) | ||
|
||
|
||
def test_squeeze_implicit_dims(): | ||
AllenDowney marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Test squeeze with implicit dim=None (all size-1 dimensions).""" | ||
|
||
# All dimensions size 1 | ||
x1 = xtensor("x1", dims=("a", "b"), shape=(1, 1)) | ||
y1 = squeeze(x1) | ||
fn1 = xr_function([x1], y1) | ||
x1_test = xr_arange_like(x1) | ||
xr_assert_allclose(fn1(x1_test), x1_test.squeeze()) | ||
|
||
# No dimensions size 1 = no-op | ||
x2 = xtensor("x2", dims=("row", "col", "batch"), shape=(2, 3, 4)) | ||
y2 = squeeze(x2) | ||
fn2 = xr_function([x2], y2) | ||
x2_test = xr_arange_like(x2) | ||
xr_assert_allclose(fn2(x2_test), x2_test) | ||
|
||
# Symbolic shape where runtime shape is 1 → should squeeze | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't call these symbolic shapes. They are just unknown static shapes. It's confusing label |
||
x3 = xtensor("x3", dims=("a", "b", "c")) # shape unknown | ||
y3 = squeeze(x3, "b") | ||
x3_test = xr_arange_like(xtensor(dims=x3.dims, shape=(2, 1, 3))) | ||
fn3 = xr_function([x3], y3) | ||
xr_assert_allclose(fn3(x3_test), x3_test.squeeze("b")) | ||
|
||
# Mixed static + symbolic shapes, where symbolic shape is 1 | ||
x4 = xtensor("x4", dims=("a", "b", "c"), shape=(None, 1, 3)) | ||
y4 = squeeze(x4, "b") | ||
x4_test = xr_arange_like(xtensor(dims=x4.dims, shape=(4, 1, 3))) | ||
fn4 = xr_function([x4], y4) | ||
xr_assert_allclose(fn4(x4_test), x4_test.squeeze("b")) | ||
Comment on lines
+330
to
+335
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test is not interesting, remove? Or what was the reason for it? |
||
|
||
""" | ||
This test documents that we intentionally don't squeeze dimensions with symbolic shapes | ||
(static_shape=None) even when they are 1 at runtime, while xarray does squeeze them. | ||
""" | ||
# Create a tensor with a symbolic dimension that will be 1 at runtime | ||
x = xtensor("x", dims=("a", "b", "c")) # shape unknown | ||
y = squeeze(x) # implicit dim=None should not squeeze symbolic dimensions | ||
x_test = xr_arange_like(xtensor(dims=x.dims, shape=(2, 1, 3))) | ||
fn = xr_function([x], y) | ||
res = fn(x_test) | ||
|
||
# Our implementation should not squeeze the symbolic dimension | ||
assert "b" in res.dims | ||
# While xarray would squeeze it | ||
assert "b" not in x_test.squeeze().dims | ||
|
||
|
||
def test_squeeze_errors(): | ||
"""Test error cases for squeeze.""" | ||
|
||
# Non-existent dimension | ||
x1 = xtensor("x1", dims=("city", "country"), shape=(3, 1)) | ||
with pytest.raises(ValueError, match="Dimension .* not found"): | ||
squeeze(x1, "time") | ||
|
||
# Dimension size > 1 | ||
with pytest.raises(ValueError, match="has static size .* not 1"): | ||
squeeze(x1, "city") | ||
|
||
# Symbolic shape: dim is not 1 at runtime → should raise | ||
x2 = xtensor("x2", dims=("a", "b", "c")) # shape unknown | ||
y2 = squeeze(x2, "b") | ||
x2_test = xr_arange_like(xtensor(dims=x2.dims, shape=(2, 2, 3))) | ||
fn2 = xr_function([x2], y2) | ||
with pytest.raises(Exception): | ||
fn2(x2_test) |
Uh oh!
There was an error while loading. Please reload this page.