Skip to content

ENH: speed up array_namespace #329

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 2, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 106 additions & 90 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from __future__ import annotations

import enum
import inspect
import math
import sys
Expand Down Expand Up @@ -485,6 +486,86 @@ def _check_api_version(api_version: str | None) -> None:
)


class _ClsToXPInfo(enum.Enum):
SCALAR = 0
MAYBE_JAX_ZERO_GRADIENT = 1


@lru_cache(100)
def _cls_to_namespace(
cls: type,
api_version: str | None,
use_compat: bool | None,
) -> tuple[Namespace | None, _ClsToXPInfo | None]:
if use_compat not in (None, True, False):
raise ValueError("use_compat must be None, True, or False")
_use_compat = use_compat in (None, True)
cls_ = cast(Hashable, cls) # Make mypy happy

if (
_issubclass_fast(cls_, "numpy", "ndarray")
or _issubclass_fast(cls_, "numpy", "generic")
):
if use_compat is True:
_check_api_version(api_version)
from .. import numpy as xp
elif use_compat is False:
import numpy as xp # type: ignore[no-redef]
else:
# NumPy 2.0+ have __array_namespace__; however they are not
# yet fully array API compatible.
from .. import numpy as xp # type: ignore[no-redef]
return xp, _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT

# Note: this must happen _after_ the test for np.generic,
# because np.float64 and np.complex128 are subclasses of float and complex.
if issubclass(cls, int | float | complex | type(None)):
return None, _ClsToXPInfo.SCALAR

if _issubclass_fast(cls_, "cupy", "ndarray"):
if _use_compat:
_check_api_version(api_version)
from .. import cupy as xp # type: ignore[no-redef]
else:
import cupy as xp # type: ignore[no-redef]
return xp, None

if _issubclass_fast(cls_, "torch", "Tensor"):
if _use_compat:
_check_api_version(api_version)
from .. import torch as xp # type: ignore[no-redef]
else:
import torch as xp # type: ignore[no-redef]
return xp, None

if _issubclass_fast(cls_, "dask.array", "Array"):
if _use_compat:
_check_api_version(api_version)
from ..dask import array as xp # type: ignore[no-redef]
else:
import dask.array as xp # type: ignore[no-redef]
return xp, None

# Backwards compatibility for jax<0.4.32
if _issubclass_fast(cls_, "jax", "Array"):
return _jax_namespace(api_version, use_compat), None

return None, None


def _jax_namespace(api_version: str | None, use_compat: bool | None) -> Namespace:
if use_compat:
raise ValueError("JAX does not have an array-api-compat wrapper")
import jax.numpy as jnp
if not hasattr(jnp, "__array_namespace_info__"):
# JAX v0.4.32 and newer implements the array API directly in jax.numpy.
# For older JAX versions, it is available via jax.experimental.array_api.
# jnp.Array objects gain the __array_namespace__ method.
import jax.experimental.array_api # noqa: F401
# Test api_version
return jnp.empty(0).__array_namespace__(api_version=api_version)


def array_namespace(
*xs: Array | complex | None,
api_version: str | None = None,
Expand Down Expand Up @@ -553,105 +634,40 @@ def your_function(x, y):
is_pydata_sparse_array

"""
if use_compat not in [None, True, False]:
raise ValueError("use_compat must be None, True, or False")

_use_compat = use_compat in [None, True]

namespaces: set[Namespace] = set()
for x in xs:
if is_numpy_array(x):
import numpy as np

from .. import numpy as numpy_namespace

if use_compat is True:
_check_api_version(api_version)
namespaces.add(numpy_namespace)
elif use_compat is False:
namespaces.add(np)
else:
# numpy 2.0+ have __array_namespace__, however, they are not yet fully array API
# compatible.
namespaces.add(numpy_namespace)
elif is_cupy_array(x):
if _use_compat:
_check_api_version(api_version)
from .. import cupy as cupy_namespace

namespaces.add(cupy_namespace)
else:
import cupy as cp # pyright: ignore[reportMissingTypeStubs]

namespaces.add(cp)
elif is_torch_array(x):
if _use_compat:
_check_api_version(api_version)
from .. import torch as torch_namespace

namespaces.add(torch_namespace)
else:
import torch

namespaces.add(torch)
elif is_dask_array(x):
if _use_compat:
_check_api_version(api_version)
from ..dask import array as dask_namespace

namespaces.add(dask_namespace)
else:
import dask.array as da

namespaces.add(da)
elif is_jax_array(x):
if use_compat is True:
_check_api_version(api_version)
raise ValueError("JAX does not have an array-api-compat wrapper")
elif use_compat is False:
import jax.numpy as jnp
else:
# JAX v0.4.32 and newer implements the array API directly in jax.numpy.
# For older JAX versions, it is available via jax.experimental.array_api.
import jax.numpy

if hasattr(jax.numpy, "__array_api_version__"):
jnp = jax.numpy
else:
import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports]
namespaces.add(jnp)
elif is_pydata_sparse_array(x):
if use_compat is True:
_check_api_version(api_version)
raise ValueError("`sparse` does not have an array-api-compat wrapper")
else:
import sparse # pyright: ignore[reportMissingTypeStubs]
# `sparse` is already an array namespace. We do not have a wrapper
# submodule for it.
namespaces.add(sparse)
elif hasattr(x, "__array_namespace__"):
if use_compat is True:
xp, info = _cls_to_namespace(cast(Hashable, type(x)), api_version, use_compat)
if info is _ClsToXPInfo.SCALAR:
continue

if (
info is _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT
and _is_jax_zero_gradient_array(x)
):
xp = _jax_namespace(api_version, use_compat)

if xp is None:
get_ns = getattr(x, "__array_namespace__", None)
if get_ns is None:
raise TypeError(f"{type(x).__name__} is not a supported array type")
if use_compat:
raise ValueError(
"The given array does not have an array-api-compat wrapper"
)
x = cast("SupportsArrayNamespace[Any]", x)
namespaces.add(x.__array_namespace__(api_version=api_version))
elif isinstance(x, (bool, int, float, complex, type(None))):
continue
else:
# TODO: Support Python scalars?
raise TypeError(f"{type(x).__name__} is not a supported array type")
xp = get_ns(api_version=api_version)

if not namespaces:
raise TypeError("Unrecognized array input")
namespaces.add(xp)

if len(namespaces) != 1:
try:
(xp,) = namespaces
return xp
except ValueError:
if not namespaces:
raise TypeError(
"array_namespace requires at least one non-scalar array input"
)
raise TypeError(f"Multiple namespaces for array inputs: {namespaces}")

(xp,) = namespaces

return xp


# backwards compatibility alias
get_namespace = array_namespace
Expand Down
Loading
Loading