Skip to content

Commit d40609a

Browse files
dcherianmax-sixty
andauthored
Use opt_einsum by default if installed. (#8373)
* Use `opt_einsum` by default if installed. Closes #7764 Closes #8017 * docstring update * _ * _ Co-authored-by: Maximilian Roos <[email protected]> * Update xarray/core/computation.py Co-authored-by: Maximilian Roos <[email protected]> * Fix docs? * Add use_opt_einsum option. * mypy ignore * one more test ignore * Disable navigation_with_keys * remove intersphinx * One more skip --------- Co-authored-by: Maximilian Roos <[email protected]>
1 parent bb489fa commit d40609a

File tree

9 files changed

+52
-15
lines changed

9 files changed

+52
-15
lines changed

ci/install-upstream-wheels.sh

+2-1
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,5 @@ python -m pip install \
4545
git+https://github.com/intake/filesystem_spec \
4646
git+https://github.com/SciTools/nc-time-axis \
4747
git+https://github.com/xarray-contrib/flox \
48-
git+https://github.com/h5netcdf/h5netcdf
48+
git+https://github.com/h5netcdf/h5netcdf \
49+
git+https://github.com/dgasmith/opt_einsum

ci/requirements/environment.yml

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ dependencies:
2626
- numbagg
2727
- numexpr
2828
- numpy
29+
- opt_einsum
2930
- packaging
3031
- pandas
3132
- pint<0.21

doc/conf.py

+2
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@
237237
use_repository_button=True,
238238
use_issues_button=True,
239239
home_page_in_toc=False,
240+
navigation_with_keys=False,
240241
extra_footer="""<p>Xarray is a fiscally sponsored project of <a href="https://numfocus.org">NumFOCUS</a>,
241242
a nonprofit dedicated to supporting the open-source scientific computing community.<br>
242243
Theme by the <a href="https://ebp.jupyterbook.org">Executable Book Project</a></p>""",
@@ -327,6 +328,7 @@
327328
"sparse": ("https://sparse.pydata.org/en/latest/", None),
328329
"cubed": ("https://tom-e-white.com/cubed/", None),
329330
"datatree": ("https://xarray-datatree.readthedocs.io/en/latest/", None),
331+
# "opt_einsum": ("https://dgasmith.github.io/opt_einsum/", None),
330332
}
331333

332334

doc/whats-new.rst

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ v2023.10.2 (unreleased)
2222
New Features
2323
~~~~~~~~~~~~
2424

25+
- Use `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_ for :py:func:`xarray.dot` by default if installed.
26+
By `Deepak Cherian <https://github.com/dcherian>`_. (:issue:`7764`, :pull:`8373`).
2527

2628
Breaking changes
2729
~~~~~~~~~~~~~~~~

pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ source-code = "https://github.com/pydata/xarray"
3838
dask = "xarray.core.daskmanager:DaskManager"
3939

4040
[project.optional-dependencies]
41-
accel = ["scipy", "bottleneck", "numbagg", "flox"]
41+
accel = ["scipy", "bottleneck", "numbagg", "flox", "opt_einsum"]
4242
complete = ["xarray[accel,io,parallel,viz]"]
4343
io = ["netCDF4", "h5netcdf", "scipy", 'pydap; python_version<"3.10"', "zarr", "fsspec", "cftime", "pooch"]
4444
parallel = ["dask[complete]"]
@@ -106,6 +106,7 @@ module = [
106106
"numbagg.*",
107107
"netCDF4.*",
108108
"netcdftime.*",
109+
"opt_einsum.*",
109110
"pandas.*",
110111
"pooch.*",
111112
"PseudoNetCDF.*",

xarray/core/computation.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -1690,8 +1690,8 @@ def dot(
16901690
dims: Dims = None,
16911691
**kwargs: Any,
16921692
):
1693-
"""Generalized dot product for xarray objects. Like np.einsum, but
1694-
provides a simpler interface based on array dimensions.
1693+
"""Generalized dot product for xarray objects. Like ``np.einsum``, but
1694+
provides a simpler interface based on array dimension names.
16951695
16961696
Parameters
16971697
----------
@@ -1701,13 +1701,24 @@ def dot(
17011701
Which dimensions to sum over. Ellipsis ('...') sums over all dimensions.
17021702
If not specified, then all the common dimensions are summed over.
17031703
**kwargs : dict
1704-
Additional keyword arguments passed to numpy.einsum or
1705-
dask.array.einsum
1704+
Additional keyword arguments passed to ``numpy.einsum`` or
1705+
``dask.array.einsum``
17061706
17071707
Returns
17081708
-------
17091709
DataArray
17101710
1711+
See Also
1712+
--------
1713+
numpy.einsum
1714+
dask.array.einsum
1715+
opt_einsum.contract
1716+
1717+
Notes
1718+
-----
1719+
We recommend installing the optional ``opt_einsum`` package, or alternatively passing ``optimize=True``,
1720+
which is passed through to ``np.einsum``, and works for most array backends.
1721+
17111722
Examples
17121723
--------
17131724
>>> da_a = xr.DataArray(np.arange(3 * 2).reshape(3, 2), dims=["a", "b"])

xarray/core/duck_array_ops.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from numpy import any as array_any # noqa
1919
from numpy import ( # noqa
2020
around, # noqa
21-
einsum,
2221
gradient,
2322
isclose,
2423
isin,
@@ -48,6 +47,17 @@ def get_array_namespace(x):
4847
return np
4948

5049

50+
def einsum(*args, **kwargs):
51+
from xarray.core.options import OPTIONS
52+
53+
if OPTIONS["use_opt_einsum"] and module_available("opt_einsum"):
54+
import opt_einsum
55+
56+
return opt_einsum.contract(*args, **kwargs)
57+
else:
58+
return np.einsum(*args, **kwargs)
59+
60+
5161
def _dask_or_eager_func(
5262
name,
5363
eager_module=np,

xarray/core/options.py

+6
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
"warn_for_unclosed_files",
2929
"use_bottleneck",
3030
"use_numbagg",
31+
"use_opt_einsum",
3132
"use_flox",
3233
]
3334

@@ -52,6 +53,7 @@ class T_Options(TypedDict):
5253
use_bottleneck: bool
5354
use_flox: bool
5455
use_numbagg: bool
56+
use_opt_einsum: bool
5557

5658

5759
OPTIONS: T_Options = {
@@ -75,6 +77,7 @@ class T_Options(TypedDict):
7577
"use_bottleneck": True,
7678
"use_flox": True,
7779
"use_numbagg": True,
80+
"use_opt_einsum": True,
7881
}
7982

8083
_JOIN_OPTIONS = frozenset(["inner", "outer", "left", "right", "exact"])
@@ -102,6 +105,7 @@ def _positive_integer(value: int) -> bool:
102105
"keep_attrs": lambda choice: choice in [True, False, "default"],
103106
"use_bottleneck": lambda value: isinstance(value, bool),
104107
"use_numbagg": lambda value: isinstance(value, bool),
108+
"use_opt_einsum": lambda value: isinstance(value, bool),
105109
"use_flox": lambda value: isinstance(value, bool),
106110
"warn_for_unclosed_files": lambda value: isinstance(value, bool),
107111
}
@@ -237,6 +241,8 @@ class set_options:
237241
use_numbagg : bool, default: True
238242
Whether to use ``numbagg`` to accelerate reductions.
239243
Takes precedence over ``use_bottleneck`` when both are True.
244+
use_opt_einsum : bool, default: True
245+
Whether to use ``opt_einsum`` to accelerate dot products.
240246
warn_for_unclosed_files : bool, default: False
241247
Whether or not to issue a warning when unclosed files are
242248
deallocated. This is mostly useful for debugging.

xarray/tests/test_units.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -1502,10 +1502,11 @@ def test_dot_dataarray(dtype):
15021502
data_array = xr.DataArray(data=array1, dims=("x", "y"))
15031503
other = xr.DataArray(data=array2, dims=("y", "z"))
15041504

1505-
expected = attach_units(
1506-
xr.dot(strip_units(data_array), strip_units(other)), {None: unit_registry.m}
1507-
)
1508-
actual = xr.dot(data_array, other)
1505+
with xr.set_options(use_opt_einsum=False):
1506+
expected = attach_units(
1507+
xr.dot(strip_units(data_array), strip_units(other)), {None: unit_registry.m}
1508+
)
1509+
actual = xr.dot(data_array, other)
15091510

15101511
assert_units_equal(expected, actual)
15111512
assert_identical(expected, actual)
@@ -2465,8 +2466,9 @@ def test_binary_operations(self, func, dtype):
24652466
data_array = xr.DataArray(data=array)
24662467

24672468
units = extract_units(func(array))
2468-
expected = attach_units(func(strip_units(data_array)), units)
2469-
actual = func(data_array)
2469+
with xr.set_options(use_opt_einsum=False):
2470+
expected = attach_units(func(strip_units(data_array)), units)
2471+
actual = func(data_array)
24702472

24712473
assert_units_equal(expected, actual)
24722474
assert_identical(expected, actual)
@@ -3829,8 +3831,9 @@ def test_computation(self, func, variant, dtype):
38293831
if not isinstance(func, (function, method)):
38303832
units.update(extract_units(func(array.reshape(-1))))
38313833

3832-
expected = attach_units(func(strip_units(data_array)), units)
3833-
actual = func(data_array)
3834+
with xr.set_options(use_opt_einsum=False):
3835+
expected = attach_units(func(strip_units(data_array)), units)
3836+
actual = func(data_array)
38343837

38353838
assert_units_equal(expected, actual)
38363839
assert_identical(expected, actual)

0 commit comments

Comments
 (0)