Skip to content

Commit 6f13cfa

Browse files
authored
Merge branch 'pymc-devs:main' into main
2 parents e7003bf + 5c8afae commit 6f13cfa

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

50 files changed

+1905
-737
lines changed

.github/workflows/pypi.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ jobs:
4949
fetch-depth: 0
5050

5151
- name: Build wheels
52-
uses: pypa/cibuildwheel@v2.18.1
52+
uses: pypa/cibuildwheel@v2.19.1
5353

5454
- uses: actions/upload-artifact@v3
5555
with:
@@ -88,7 +88,7 @@ jobs:
8888
name: artifact
8989
path: dist
9090

91-
- uses: pypa/gh-action-pypi-publish@v1.8.14
91+
- uses: pypa/gh-action-pypi-publish@v1.9.0
9292
with:
9393
user: __token__
9494
password: ${{ secrets.pypi_password }}

.github/workflows/test.yml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ jobs:
7676
float32: [0, 1]
7777
install-numba: [0]
7878
install-jax: [0]
79+
install-torch: [0]
7980
part:
8081
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
8182
- "tests/scan"
@@ -116,6 +117,11 @@ jobs:
116117
fast-compile: 0
117118
float32: 0
118119
part: "tests/link/jax"
120+
- install-torch: 1
121+
python-version: "3.10"
122+
fast-compile: 0
123+
float32: 0
124+
part: "tests/link/pytorch"
119125
steps:
120126
- uses: actions/checkout@v4
121127
with:
@@ -142,9 +148,12 @@ jobs:
142148
- name: Install dependencies
143149
shell: micromamba-shell {0}
144150
run: |
151+
145152
micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy
146153
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi
147154
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
155+
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch pytorch-cuda=12.1 -c pytorch -c nvidia; fi
156+
148157
pip install -e ./
149158
micromamba list && pip freeze
150159
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
@@ -153,6 +162,7 @@ jobs:
153162
PYTHON_VERSION: ${{ matrix.python-version }}
154163
INSTALL_NUMBA: ${{ matrix.install-numba }}
155164
INSTALL_JAX: ${{ matrix.install-jax }}
165+
INSTALL_TORCH: ${{ matrix.install-torch}}
156166

157167
- name: Run tests
158168
shell: micromamba-shell {0}
@@ -199,7 +209,7 @@ jobs:
199209
- name: Install dependencies
200210
shell: micromamba-shell {0}
201211
run: |
202-
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark
212+
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
203213
pip install -e ./
204214
micromamba list && pip freeze
205215
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
@@ -268,3 +278,4 @@ jobs:
268278
directory: ./coverage/
269279
fail_ci_if_error: true
270280
token: ${{ secrets.CODECOV_TOKEN }}
281+

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ repos:
2222
)$
2323
- id: check-merge-conflict
2424
- repo: https://github.com/astral-sh/ruff-pre-commit
25-
rev: v0.4.8
25+
rev: v0.4.10
2626
hooks:
2727
- id: ruff
2828
args: ["--fix", "--output-format=full"]

environment.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ channels:
99
dependencies:
1010
- python>=3.10
1111
- compilers
12-
- numpy>=1.17.0
13-
- scipy>=0.14
12+
- numpy>=1.17.0,<2
13+
- scipy>=0.14,<1.14.0
1414
- filelock
1515
- etuples
1616
- logical-unification

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ keywords = [
4747
]
4848
dependencies = [
4949
"setuptools>=59.0.0",
50-
"scipy>=0.14",
50+
"scipy>=0.14,<1.14",
5151
"numpy>=1.17.0,<2",
5252
"filelock",
5353
"etuples",

pytensor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def as_symbolic(x: Any, name: str | None = None, **kwargs) -> Variable:
108108

109109

110110
@singledispatch
111-
def _as_symbolic(x, **kwargs) -> Variable:
111+
def _as_symbolic(x: Any, **kwargs) -> Variable:
112112
from pytensor.tensor import as_tensor_variable
113113

114114
return as_tensor_variable(x, **kwargs)

pytensor/compile/mode.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from pytensor.link.c.basic import CLinker, OpWiseCLinker
2929
from pytensor.link.jax.linker import JAXLinker
3030
from pytensor.link.numba.linker import NumbaLinker
31+
from pytensor.link.pytorch.linker import PytorchLinker
3132
from pytensor.link.vm import VMLinker
3233

3334

@@ -47,6 +48,7 @@
4748
"vm_nogc": VMLinker(allow_gc=False, use_cloop=False),
4849
"cvm_nogc": VMLinker(allow_gc=False, use_cloop=True),
4950
"jax": JAXLinker(),
51+
"pytorch": PytorchLinker(),
5052
"numba": NumbaLinker(),
5153
}
5254

@@ -460,6 +462,18 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
460462
],
461463
),
462464
)
465+
PYTORCH = Mode(
466+
PytorchLinker(),
467+
RewriteDatabaseQuery(
468+
include=["fast_run"],
469+
exclude=[
470+
"cxx_only",
471+
"BlasOpt",
472+
"fusion",
473+
"inplace",
474+
],
475+
),
476+
)
463477
NUMBA = Mode(
464478
NumbaLinker(),
465479
RewriteDatabaseQuery(
@@ -474,6 +488,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
474488
"FAST_RUN": FAST_RUN,
475489
"JAX": JAX,
476490
"NUMBA": NUMBA,
491+
"PYTORCH": PYTORCH,
477492
}
478493

479494
instantiated_default_mode = None

pytensor/graph/basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,8 +1302,8 @@ def clone_node_and_cache(
13021302

13031303

13041304
def clone_get_equiv(
1305-
inputs: Sequence[Variable],
1306-
outputs: Sequence[Variable],
1305+
inputs: Iterable[Variable],
1306+
outputs: Reversible[Variable],
13071307
copy_inputs: bool = True,
13081308
copy_orphans: bool = True,
13091309
memo: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]

pytensor/link/basic.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,10 @@ def create_thunk_inputs(self, storage_map: dict[Variable, list[Any]]) -> list[An
600600
def jit_compile(self, fn: Callable) -> Callable:
601601
"""JIT compile a converted ``FunctionGraph``."""
602602

603+
def input_filter(self, inp: Any) -> Any:
604+
"""Apply a filter to the data input."""
605+
return inp
606+
603607
def output_filter(self, var: Variable, out: Any) -> Any:
604608
"""Apply a filter to the data output by a JITed function call."""
605609
return out
@@ -657,7 +661,7 @@ def thunk(
657661
thunk_inputs=thunk_inputs,
658662
thunk_outputs=thunk_outputs,
659663
):
660-
outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
664+
outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs])
661665

662666
for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
663667
compute_map[o_var][0] = True

pytensor/link/jax/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@
1414
import pytensor.link.jax.dispatch.scan
1515
import pytensor.link.jax.dispatch.sparse
1616
import pytensor.link.jax.dispatch.blockwise
17+
import pytensor.link.jax.dispatch.sort
1718

1819
# isort: on

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,30 @@
11
import jax
22

33
from pytensor.link.jax.dispatch.basic import jax_funcify
4-
from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, Solve, SolveTriangular
4+
from pytensor.tensor.slinalg import (
5+
BlockDiagonal,
6+
Cholesky,
7+
Eigvalsh,
8+
Solve,
9+
SolveTriangular,
10+
)
11+
12+
13+
@jax_funcify.register(Eigvalsh)
14+
def jax_funcify_Eigvalsh(op, **kwargs):
15+
if op.lower:
16+
UPLO = "L"
17+
else:
18+
UPLO = "U"
19+
20+
def eigvalsh(a, b):
21+
if b is not None:
22+
raise NotImplementedError(
23+
"jax.numpy.linalg.eigvalsh does not support generalized eigenvector problems (b != None)"
24+
)
25+
return jax.numpy.linalg.eigvalsh(a, UPLO=UPLO)
26+
27+
return eigvalsh
528

629

730
@jax_funcify.register(Cholesky)

pytensor/link/jax/dispatch/sort.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from jax import numpy as jnp
2+
3+
from pytensor.link.jax.dispatch import jax_funcify
4+
from pytensor.tensor.sort import ArgSortOp, SortOp
5+
6+
7+
@jax_funcify.register(SortOp)
8+
def jax_funcify_Sort(op, **kwargs):
9+
stable = op.kind == "stable"
10+
11+
def sort(arr, axis):
12+
return jnp.sort(arr, axis=axis, stable=stable)
13+
14+
return sort
15+
16+
17+
@jax_funcify.register(ArgSortOp)
18+
def jax_funcify_ArgSort(op, **kwargs):
19+
stable = op.kind == "stable"
20+
21+
def argsort(arr, axis):
22+
return jnp.argsort(arr, axis=axis, stable=stable)
23+
24+
return argsort

pytensor/link/jax/dispatch/subtensor.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,36 +31,20 @@
3131
"""
3232

3333

34-
def subtensor_assert_indices_jax_compatible(node, idx_list):
35-
from pytensor.graph.basic import Constant
36-
from pytensor.tensor.variable import TensorVariable
37-
38-
ilist = indices_from_subtensor(node.inputs[1:], idx_list)
39-
for idx in ilist:
40-
if isinstance(idx, TensorVariable):
41-
if idx.type.dtype == "bool":
42-
raise NotImplementedError(BOOLEAN_MASK_ERROR)
43-
elif isinstance(idx, slice):
44-
for slice_arg in (idx.start, idx.stop, idx.step):
45-
if slice_arg is not None and not isinstance(slice_arg, Constant):
46-
raise NotImplementedError(DYNAMIC_SLICE_LENGTH_ERROR)
47-
48-
4934
@jax_funcify.register(Subtensor)
5035
@jax_funcify.register(AdvancedSubtensor)
5136
@jax_funcify.register(AdvancedSubtensor1)
5237
def jax_funcify_Subtensor(op, node, **kwargs):
5338
idx_list = getattr(op, "idx_list", None)
54-
subtensor_assert_indices_jax_compatible(node, idx_list)
5539

56-
def subtensor_constant(x, *ilists):
40+
def subtensor(x, *ilists):
5741
indices = indices_from_subtensor(ilists, idx_list)
5842
if len(indices) == 1:
5943
indices = indices[0]
6044

6145
return x.__getitem__(indices)
6246

63-
return subtensor_constant
47+
return subtensor
6448

6549

6650
@jax_funcify.register(IncSubtensor)

pytensor/link/jax/dispatch/tensor_basic.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
)
2323
from pytensor.tensor.exceptions import NotScalarConstantError
2424
from pytensor.tensor.shape import Shape_i
25-
from pytensor.tensor.sort import SortOp
2625

2726

2827
ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange`
@@ -206,11 +205,3 @@ def tri(*args):
206205
return jnp.tri(*args, dtype=op.dtype)
207206

208207
return tri
209-
210-
211-
@jax_funcify.register(SortOp)
212-
def jax_funcify_Sort(op, **kwargs):
213-
def sort(arr, axis):
214-
return jnp.sort(arr, axis=axis)
215-
216-
return sort

pytensor/link/numba/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@
1111
import pytensor.link.numba.dispatch.scan
1212
import pytensor.link.numba.dispatch.sparse
1313
import pytensor.link.numba.dispatch.slinalg
14+
import pytensor.link.numba.dispatch.subtensor
1415

1516
# isort: on

0 commit comments

Comments
 (0)