Skip to content

Commit 07b715c

Browse files
Split jax tests into their own workflow
1 parent 78d15f4 commit 07b715c

File tree

3 files changed

+77
-8
lines changed

3 files changed

+77
-8
lines changed

.github/workflows/jaxtests.yml

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
name: jax-sampling
2+
3+
on:
4+
pull_request:
5+
push:
6+
branches: [master]
7+
8+
jobs:
9+
pytest:
10+
strategy:
11+
matrix:
12+
os: [ubuntu-latest]
13+
floatx: [float64]
14+
test-subset:
15+
- pymc3/tests/test_sampling_jax.py
16+
fail-fast: false
17+
runs-on: ${{ matrix.os }}
18+
env:
19+
TEST_SUBSET: ${{ matrix.test-subset }}
20+
THEANO_FLAGS: floatX=${{ matrix.floatx }},gcc__cxxflags='-march=native'
21+
defaults:
22+
run:
23+
shell: bash -l {0}
24+
steps:
25+
- uses: actions/checkout@v2
26+
- name: Cache conda
27+
uses: actions/cache@v1
28+
env:
29+
# Increase this value to reset cache if environment-dev-py39.yml has not changed
30+
CACHE_NUMBER: 0
31+
with:
32+
path: ~/conda_pkgs_dir
33+
key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{
34+
hashFiles('conda-envs/environment-dev-py39.yml') }}
35+
- name: Cache multiple paths
36+
uses: actions/cache@v2
37+
env:
38+
# Increase this value to reset cache if requirements.txt has not changed
39+
CACHE_NUMBER: 0
40+
with:
41+
path: |
42+
~/.cache/pip
43+
$RUNNER_TOOL_CACHE/Python/*
44+
~\AppData\Local\pip\Cache
45+
key: ${{ runner.os }}-build-${{ matrix.python-version }}-${{
46+
hashFiles('requirements.txt') }}
47+
- uses: conda-incubator/setup-miniconda@v2
48+
with:
49+
activate-environment: pymc3-dev-py39
50+
channel-priority: strict
51+
environment-file: conda-envs/environment-dev-py39.yml
52+
use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly!
53+
- name: Install pymc3
54+
run: |
55+
conda activate pymc3-dev-py39
56+
pip install -e .
57+
python --version
58+
- name: Install jax specific dependencies
59+
run: |
60+
conda activate pymc3-dev-py39
61+
pip install numpyro tensorflow_probability
62+
- name: Run tests
63+
run: |
64+
python -m pytest -vv --cov=pymc3 --cov-report=xml --cov-report term --durations=50 $TEST_SUBSET

.github/workflows/pytest.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ jobs:
2727
--ignore=pymc3/tests/test_quadpotential.py
2828
--ignore=pymc3/tests/test_random.py
2929
--ignore=pymc3/tests/test_sampling.py
30+
--ignore=pymc3/tests/test_sampling_jax.py
3031
--ignore=pymc3/tests/test_shape_handling.py
3132
--ignore=pymc3/tests/test_shared.py
3233
--ignore=pymc3/tests/test_smc.py

scripts/check_all_tests_are_covered.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,17 @@
1212
from pathlib import Path
1313

1414
if __name__ == "__main__":
15-
pytest_ci_job = Path(".github") / "workflows/pytest.yml"
16-
txt = pytest_ci_job.read_text()
17-
ignored_tests = set(re.findall(r"(?<=--ignore=)(pymc3/tests.*\.py)", txt))
18-
non_ignored_tests = set(re.findall(r"(?<!--ignore=)(pymc3/tests.*\.py)", txt))
15+
testing_workflows = ["jaxtests.yml", "pytest.yml"]
16+
ignored = set()
17+
non_ignored = set()
18+
for wfyml in testing_workflows:
19+
pytest_ci_job = Path(".github") / "workflows" / wfyml
20+
txt = pytest_ci_job.read_text()
21+
ignored = set(re.findall(r"(?<=--ignore=)(pymc3/tests.*\.py)", txt))
22+
non_ignored = non_ignored.union(set(re.findall(r"(?<!--ignore=)(pymc3/tests.*\.py)", txt)))
1923
assert (
20-
ignored_tests <= non_ignored_tests
21-
), f"The following tests are ignored by the first job but not run by the others: {ignored_tests.difference(non_ignored_tests)}"
24+
ignored <= non_ignored
25+
), f"The following tests are ignored by the first job but not run by the others: {ignored.difference(non_ignored)}"
2226
assert (
23-
ignored_tests >= non_ignored_tests
24-
), f"The following tests are run by multiple jobs: {non_ignored_tests.difference(ignored_tests)}"
27+
ignored >= non_ignored
28+
), f"The following tests are run by multiple jobs: {non_ignored.difference(ignored)}"

0 commit comments

Comments
 (0)