Skip to content

Commit 692c53c

Browse files
committed
Test xtensor module
1 parent 1ab3e09 commit 692c53c

File tree

6 files changed

+23
-4
lines changed

6 files changed

+23
-4
lines changed

.github/workflows/test.yml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ jobs:
8282
install-numba: [0]
8383
install-jax: [0]
8484
install-torch: [0]
85+
install-xarray: [0]
8586
part:
8687
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
8788
- "tests/scan"
@@ -115,6 +116,7 @@ jobs:
115116
install-numba: 0
116117
install-jax: 0
117118
install-torch: 0
119+
install-xarray: 0
118120
- install-numba: 1
119121
os: "ubuntu-latest"
120122
python-version: "3.10"
@@ -150,6 +152,13 @@ jobs:
150152
fast-compile: 0
151153
float32: 0
152154
part: "tests/link/pytorch"
155+
- install-xarray: 1
156+
os: "ubuntu-latest"
157+
python-version: "3.13"
158+
numpy-version: ">=2.0"
159+
fast-compile: 0
160+
float32: 0
161+
part: "tests/xtensor"
153162
- os: macos-15
154163
python-version: "3.13"
155164
numpy-version: ">=2.0"
@@ -196,6 +205,7 @@ jobs:
196205
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
197206
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi
198207
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
208+
if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi
199209
pip install pytest-sphinx
200210
201211
pip install -e ./
@@ -212,6 +222,7 @@ jobs:
212222
INSTALL_NUMBA: ${{ matrix.install-numba }}
213223
INSTALL_JAX: ${{ matrix.install-jax }}
214224
INSTALL_TORCH: ${{ matrix.install-torch}}
225+
INSTALL_XARRAY: ${{ matrix.install-xarray }}
215226
OS: ${{ matrix.os}}
216227

217228
- name: Run tests

tests/xtensor/test_linalg.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# ruff: noqa: E402
2-
32
import pytest
43

54

tests/xtensor/test_math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
# ruff: noqa: E402
12
import pytest
23

34

4-
# ruff: noqa: E402
55
pytest.importorskip("xarray") #
66

77
import numpy as np

tests/xtensor/test_reduction.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
# ruff: noqa: E402
12
import pytest
23

4+
5+
pytest.importorskip("xarray")
6+
37
from pytensor.xtensor.type import xtensor
48
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function
59

tests/xtensor/test_shape.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
# ruff: noqa: E402
2-
import re
3-
42
import pytest
53

64

75
pytest.importorskip("xarray")
86

7+
import re
98
from itertools import chain, combinations
109

1110
import numpy as np

tests/xtensor/util.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# ruff: noqa: E402
2+
import pytest
3+
4+
5+
pytest.importorskip("xarray")
6+
17
import numpy as np
28
from xarray import DataArray
39
from xarray.testing import assert_allclose

0 commit comments

Comments
 (0)