Skip to content

Commit 236a3df

Browse files
ricardoV94maresb
andauthored
Fix failing CI in Python 3.8 and install numba/jax on specific runs (#326)
* Add failed assertion message in CI * Pin numpy upper bound in numba install numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.8, but not numpy, even though scipy 1.7 requires numpy<1.23. When installing PyTensor next, pip installs a lower version of numpy via the PyPI. * Run numba and jax tests in separate jobs --------- Co-authored-by: Ben Mares <[email protected]>
1 parent a8e0adc commit 236a3df

File tree

7 files changed

+54
-8
lines changed

7 files changed

+54
-8
lines changed

.github/workflows/test.yml

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ jobs:
7373
python-version: ["3.8", "3.11"]
7474
fast-compile: [0,1]
7575
float32: [0,1]
76-
install-numba: [1]
76+
install-numba: [0]
77+
install-jax: [0]
7778
part:
7879
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
7980
- "tests/scan"
@@ -93,6 +94,27 @@ jobs:
9394
part: "tests/tensor/test_math.py"
9495
- fast-compile: 1
9596
float32: 1
97+
include:
98+
- install-numba: 1
99+
python-version: "3.8"
100+
fast-compile: 0
101+
float32: 0
102+
part: "tests/link/numba"
103+
- install-numba: 1
104+
python-version: "3.11"
105+
fast-compile: 0
106+
float32: 0
107+
part: "tests/link/numba"
108+
- install-jax: 1
109+
python-version: "3.8"
110+
fast-compile: 0
111+
float32: 0
112+
part: "tests/link/jax"
113+
- install-jax: 1
114+
python-version: "3.11"
115+
fast-compile: 0
116+
float32: 0
117+
part: "tests/link/jax"
96118
steps:
97119
- uses: actions/checkout@v3
98120
with:
@@ -118,15 +140,20 @@ jobs:
118140
shell: bash -l {0}
119141
run: |
120142
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy
121-
if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi
122-
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro
143+
# numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.8, but
144+
# not numpy, even though scipy 1.7 requires numpy<1.23. When installing
145+
# PyTensor next, pip installs a lower version of numpy via the PyPI.
146+
if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION == "3.8" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numpy<1.23" "numba>=0.57" numba-scipy; fi
147+
if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION != "3.8" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi
148+
if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro; fi
123149
pip install -e ./
124150
mamba list && pip freeze
125151
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
126-
python -c 'import pytensor; assert(pytensor.config.blas__ldflags != "")'
152+
python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"'
127153
env:
128154
PYTHON_VERSION: ${{ matrix.python-version }}
129155
INSTALL_NUMBA: ${{ matrix.install-numba }}
156+
INSTALL_JAX: ${{ matrix.install-jax }}
130157

131158
- name: Run tests
132159
shell: bash -l {0}
@@ -175,7 +202,7 @@ jobs:
175202
pip install -e ./
176203
mamba list && pip freeze
177204
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
178-
python -c 'import pytensor; assert(pytensor.config.blas__ldflags != "")'
205+
python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"'
179206
env:
180207
PYTHON_VERSION: 3.9
181208
- name: Download previous benchmark data

setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ per-file-ignores =
99
pytensor/link/jax/jax_dispatch.py:E402,F403,F401
1010
pytensor/link/jax/jax_linker.py:E402,F403,F401
1111
pytensor/sparse/sandbox/sp2.py:F401
12+
tests/link/jax/*.py:E402
13+
tests/link/numba/*.py:E402
1214
tests/tensor/test_math_scipy.py:E402
1315
tests/sparse/test_basic.py:E402
1416
tests/sparse/test_opt.py:E402

tests/link/jax/test_tensor_basic.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
import jax.errors
21
import numpy as np
32
import pytest
43

4+
5+
jax = pytest.importorskip("jax")
6+
import jax.errors
7+
58
import pytensor
69
import pytensor.tensor.basic as at
710
from pytensor.configdefaults import config

tests/link/numba/test_basic.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Tuple, Union
44
from unittest import mock
55

6-
import numba
76
import numpy as np
87
import pytest
98

9+
10+
numba = pytest.importorskip("numba")
11+
1012
import pytensor.scalar as aes
1113
import pytensor.scalar.math as aesm
1214
import pytensor.tensor as at

tests/link/numba/test_cython_support.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import numpy as np
22
import pytest
33
import scipy.special.cython_special
4+
5+
6+
numba = pytest.importorskip("numba")
7+
8+
49
from numba.types import float32, float64, int32, int64
510

611
from pytensor.link.numba.dispatch.cython_support import Signature, wrap_cython_function

tests/link/numba/test_performance.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import numpy as np
44
import pytest
55

6+
7+
pytest.importorskip("numba")
8+
69
import pytensor.tensor as aet
710
from pytensor import config
811
from pytensor.compile.function import function
@@ -70,4 +73,5 @@ def test_careduce_performance(careduce_fn, numpy_fn, axis, inputs, input_vals):
7073
mean_numpy_time = np.mean(numpy_times)
7174
# mean_c_time = np.mean(c_times)
7275

76+
# FIXME: Why are we asserting >=? Numba could be doing worse than numpy!
7377
assert mean_numba_time / mean_numpy_time >= 0.75

tests/link/numba/test_sparse.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
import numba
21
import numpy as np
32
import pytest
43
import scipy as sp
54

5+
6+
numba = pytest.importorskip("numba")
7+
8+
69
# Make sure the Numba customizations are loaded
710
import pytensor.link.numba.dispatch.sparse # noqa: F401
811
from pytensor import config

0 commit comments

Comments
 (0)