Skip to content

Commit 781527f

Browse files
jessegrabowski and zaxtax
authored and committed
Implement Einsum as OpFromGraph
Co-authored-by: Jesse Grabowski <[email protected]> Co-authored-by: Rob Zinkov <[email protected]>
1 parent afc1a6c commit 781527f

File tree

13 files changed

+601
-10
lines changed

13 files changed

+601
-10
lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ jobs:
149149
shell: micromamba-shell {0}
150150
run: |
151151
152-
micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy
152+
micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy opt_einsum pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy
153153
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi
154154
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
155155
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch pytorch-cuda=12.1 -c pytorch -c nvidia; fi
@@ -209,7 +209,7 @@ jobs:
209209
- name: Install dependencies
210210
shell: micromamba-shell {0}
211211
run: |
212-
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
212+
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy opt_einsum pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
213213
pip install -e ./
214214
micromamba list && pip freeze
215215
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies:
1111
- compilers
1212
- numpy>=1.17.0,<2
1313
- scipy>=0.14,<1.14.0
14+
- opt_einsum
1415
- filelock
1516
- etuples
1617
- logical-unification

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ dependencies = [
4949
"setuptools>=59.0.0",
5050
"scipy>=0.14,<1.14",
5151
"numpy>=1.17.0,<2",
52+
"opt_einsum",
5253
"filelock",
5354
"etuples",
5455
"logical-unification",

pytensor/link/jax/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import pytensor.link.jax.dispatch.scan
1515
import pytensor.link.jax.dispatch.sparse
1616
import pytensor.link.jax.dispatch.blockwise
17+
import pytensor.link.jax.dispatch.einsum
1718
import pytensor.link.jax.dispatch.sort
1819

1920
# isort: on

pytensor/link/jax/dispatch/einsum.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import jax.numpy as jnp
2+
3+
from pytensor.link.jax.dispatch import jax_funcify
4+
from pytensor.tensor.einsum import Einsum
5+
6+
7+
@jax_funcify.register(Einsum)
8+
def jax_funcify_Einsum(op, **kwargs):
9+
subscripts = op.subscripts
10+
optimize = op.optimize
11+
12+
def einsum(*operands):
13+
return jnp.einsum(subscripts, *operands, optimize=optimize)
14+
15+
return einsum

pytensor/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,5 +153,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
153153
from pytensor.tensor.functional import vectorize
154154
# isort: on
155155

156+
from pytensor.tensor.einsum import einsum
157+
156158

157159
__all__ = ["random"] # noqa: F405

pytensor/tensor/basic.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1997,7 +1997,12 @@ def transpose(x, axes=None):
19971997
_x = as_tensor_variable(x)
19981998

19991999
if axes is None:
2000-
axes = list(range((_x.type.ndim - 1), -1, -1))
2000+
axes = tuple(range((_x.type.ndim - 1), -1, -1))
2001+
2002+
if tuple(axes) == tuple(range(len(axes))):
2003+
# No-op
2004+
return _x
2005+
20012006
ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x)
20022007

20032008
if _x.name and axes == list(range((_x.type.ndim - 1), -1, -1)):
@@ -3976,6 +3981,10 @@ def moveaxis(
39763981
source = normalize_axis_tuple(source, a.ndim, "source")
39773982
destination = normalize_axis_tuple(destination, a.ndim, "destination")
39783983

3984+
if source == destination:
3985+
# It's a no-op
3986+
return a
3987+
39793988
if len(source) != len(destination):
39803989
raise ValueError(
39813990
"`source` and `destination` arguments must have the same number of elements"
@@ -4290,9 +4299,7 @@ def atleast_Nd(
42904299
atleast_3d = partial(atleast_Nd, n=3)
42914300

42924301

4293-
def expand_dims(
4294-
a: np.ndarray | TensorVariable, axis: tuple[int, ...]
4295-
) -> TensorVariable:
4302+
def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVariable:
42964303
"""Expand the shape of an array.
42974304
42984305
Insert a new axis that will appear at the `axis` position in the expanded
@@ -4311,7 +4318,7 @@ def expand_dims(
43114318
"""
43124319
a = as_tensor(a)
43134320

4314-
if not isinstance(axis, tuple | list):
4321+
if not isinstance(axis, Sequence):
43154322
axis = (axis,)
43164323

43174324
out_ndim = len(axis) + a.ndim

0 commit comments

Comments
 (0)