Skip to content

Commit 6ae28ee

Browse files
authored
ENH: speed up array_namespace
* ENH: speed up `array_namespace` * jax 0.6.1 Reviewed at #329
1 parent f7bd970 commit 6ae28ee

File tree

2 files changed

+161
-156
lines changed

2 files changed

+161
-156
lines changed

array_api_compat/common/_helpers.py

Lines changed: 106 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from __future__ import annotations
1010

11+
import enum
1112
import inspect
1213
import math
1314
import sys
@@ -485,6 +486,86 @@ def _check_api_version(api_version: str | None) -> None:
485486
)
486487

487488

489+
class _ClsToXPInfo(enum.Enum):
490+
SCALAR = 0
491+
MAYBE_JAX_ZERO_GRADIENT = 1
492+
493+
494+
@lru_cache(100)
495+
def _cls_to_namespace(
496+
cls: type,
497+
api_version: str | None,
498+
use_compat: bool | None,
499+
) -> tuple[Namespace | None, _ClsToXPInfo | None]:
500+
if use_compat not in (None, True, False):
501+
raise ValueError("use_compat must be None, True, or False")
502+
_use_compat = use_compat in (None, True)
503+
cls_ = cast(Hashable, cls) # Make mypy happy
504+
505+
if (
506+
_issubclass_fast(cls_, "numpy", "ndarray")
507+
or _issubclass_fast(cls_, "numpy", "generic")
508+
):
509+
if use_compat is True:
510+
_check_api_version(api_version)
511+
from .. import numpy as xp
512+
elif use_compat is False:
513+
import numpy as xp # type: ignore[no-redef]
514+
else:
515+
# NumPy 2.0+ have __array_namespace__; however they are not
516+
# yet fully array API compatible.
517+
from .. import numpy as xp # type: ignore[no-redef]
518+
return xp, _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT
519+
520+
# Note: this must happen _after_ the test for np.generic,
521+
# because np.float64 and np.complex128 are subclasses of float and complex.
522+
if issubclass(cls, int | float | complex | type(None)):
523+
return None, _ClsToXPInfo.SCALAR
524+
525+
if _issubclass_fast(cls_, "cupy", "ndarray"):
526+
if _use_compat:
527+
_check_api_version(api_version)
528+
from .. import cupy as xp # type: ignore[no-redef]
529+
else:
530+
import cupy as xp # type: ignore[no-redef]
531+
return xp, None
532+
533+
if _issubclass_fast(cls_, "torch", "Tensor"):
534+
if _use_compat:
535+
_check_api_version(api_version)
536+
from .. import torch as xp # type: ignore[no-redef]
537+
else:
538+
import torch as xp # type: ignore[no-redef]
539+
return xp, None
540+
541+
if _issubclass_fast(cls_, "dask.array", "Array"):
542+
if _use_compat:
543+
_check_api_version(api_version)
544+
from ..dask import array as xp # type: ignore[no-redef]
545+
else:
546+
import dask.array as xp # type: ignore[no-redef]
547+
return xp, None
548+
549+
# Backwards compatibility for jax<0.4.32
550+
if _issubclass_fast(cls_, "jax", "Array"):
551+
return _jax_namespace(api_version, use_compat), None
552+
553+
return None, None
554+
555+
556+
def _jax_namespace(api_version: str | None, use_compat: bool | None) -> Namespace:
557+
if use_compat:
558+
raise ValueError("JAX does not have an array-api-compat wrapper")
559+
import jax.numpy as jnp
560+
if not hasattr(jnp, "__array_namespace_info__"):
561+
# JAX v0.4.32 and newer implements the array API directly in jax.numpy.
562+
# For older JAX versions, it is available via jax.experimental.array_api.
563+
# jnp.Array objects gain the __array_namespace__ method.
564+
import jax.experimental.array_api # noqa: F401
565+
# Test api_version
566+
return jnp.empty(0).__array_namespace__(api_version=api_version)
567+
568+
488569
def array_namespace(
489570
*xs: Array | complex | None,
490571
api_version: str | None = None,
@@ -553,105 +634,40 @@ def your_function(x, y):
553634
is_pydata_sparse_array
554635
555636
"""
556-
if use_compat not in [None, True, False]:
557-
raise ValueError("use_compat must be None, True, or False")
558-
559-
_use_compat = use_compat in [None, True]
560-
561637
namespaces: set[Namespace] = set()
562638
for x in xs:
563-
if is_numpy_array(x):
564-
import numpy as np
565-
566-
from .. import numpy as numpy_namespace
567-
568-
if use_compat is True:
569-
_check_api_version(api_version)
570-
namespaces.add(numpy_namespace)
571-
elif use_compat is False:
572-
namespaces.add(np)
573-
else:
574-
# numpy 2.0+ have __array_namespace__, however, they are not yet fully array API
575-
# compatible.
576-
namespaces.add(numpy_namespace)
577-
elif is_cupy_array(x):
578-
if _use_compat:
579-
_check_api_version(api_version)
580-
from .. import cupy as cupy_namespace
581-
582-
namespaces.add(cupy_namespace)
583-
else:
584-
import cupy as cp # pyright: ignore[reportMissingTypeStubs]
585-
586-
namespaces.add(cp)
587-
elif is_torch_array(x):
588-
if _use_compat:
589-
_check_api_version(api_version)
590-
from .. import torch as torch_namespace
591-
592-
namespaces.add(torch_namespace)
593-
else:
594-
import torch
595-
596-
namespaces.add(torch)
597-
elif is_dask_array(x):
598-
if _use_compat:
599-
_check_api_version(api_version)
600-
from ..dask import array as dask_namespace
601-
602-
namespaces.add(dask_namespace)
603-
else:
604-
import dask.array as da
605-
606-
namespaces.add(da)
607-
elif is_jax_array(x):
608-
if use_compat is True:
609-
_check_api_version(api_version)
610-
raise ValueError("JAX does not have an array-api-compat wrapper")
611-
elif use_compat is False:
612-
import jax.numpy as jnp
613-
else:
614-
# JAX v0.4.32 and newer implements the array API directly in jax.numpy.
615-
# For older JAX versions, it is available via jax.experimental.array_api.
616-
import jax.numpy
617-
618-
if hasattr(jax.numpy, "__array_api_version__"):
619-
jnp = jax.numpy
620-
else:
621-
import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports]
622-
namespaces.add(jnp)
623-
elif is_pydata_sparse_array(x):
624-
if use_compat is True:
625-
_check_api_version(api_version)
626-
raise ValueError("`sparse` does not have an array-api-compat wrapper")
627-
else:
628-
import sparse # pyright: ignore[reportMissingTypeStubs]
629-
# `sparse` is already an array namespace. We do not have a wrapper
630-
# submodule for it.
631-
namespaces.add(sparse)
632-
elif hasattr(x, "__array_namespace__"):
633-
if use_compat is True:
639+
xp, info = _cls_to_namespace(cast(Hashable, type(x)), api_version, use_compat)
640+
if info is _ClsToXPInfo.SCALAR:
641+
continue
642+
643+
if (
644+
info is _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT
645+
and _is_jax_zero_gradient_array(x)
646+
):
647+
xp = _jax_namespace(api_version, use_compat)
648+
649+
if xp is None:
650+
get_ns = getattr(x, "__array_namespace__", None)
651+
if get_ns is None:
652+
raise TypeError(f"{type(x).__name__} is not a supported array type")
653+
if use_compat:
634654
raise ValueError(
635655
"The given array does not have an array-api-compat wrapper"
636656
)
637-
x = cast("SupportsArrayNamespace[Any]", x)
638-
namespaces.add(x.__array_namespace__(api_version=api_version))
639-
elif isinstance(x, (bool, int, float, complex, type(None))):
640-
continue
641-
else:
642-
# TODO: Support Python scalars?
643-
raise TypeError(f"{type(x).__name__} is not a supported array type")
657+
xp = get_ns(api_version=api_version)
644658

645-
if not namespaces:
646-
raise TypeError("Unrecognized array input")
659+
namespaces.add(xp)
647660

648-
if len(namespaces) != 1:
661+
try:
662+
(xp,) = namespaces
663+
return xp
664+
except ValueError:
665+
if not namespaces:
666+
raise TypeError(
667+
"array_namespace requires at least one non-scalar array input"
668+
)
649669
raise TypeError(f"Multiple namespaces for array inputs: {namespaces}")
650670

651-
(xp,) = namespaces
652-
653-
return xp
654-
655671

656672
# backwards compatibility alias
657673
get_namespace = array_namespace

0 commit comments

Comments
 (0)