|
8 | 8 |
|
9 | 9 | from __future__ import annotations
|
10 | 10 |
|
| 11 | +import enum |
11 | 12 | import inspect
|
12 | 13 | import math
|
13 | 14 | import sys
|
@@ -485,6 +486,86 @@ def _check_api_version(api_version: str | None) -> None:
|
485 | 486 | )
|
486 | 487 |
|
487 | 488 |
|
| 489 | +class _ClsToXPInfo(enum.Enum): |
| 490 | + SCALAR = 0 |
| 491 | + MAYBE_JAX_ZERO_GRADIENT = 1 |
| 492 | + |
| 493 | + |
| 494 | +@lru_cache(100) |
| 495 | +def _cls_to_namespace( |
| 496 | + cls: type, |
| 497 | + api_version: str | None, |
| 498 | + use_compat: bool | None, |
| 499 | +) -> tuple[Namespace | None, _ClsToXPInfo | None]: |
| 500 | + if use_compat not in (None, True, False): |
| 501 | + raise ValueError("use_compat must be None, True, or False") |
| 502 | + _use_compat = use_compat in (None, True) |
| 503 | + cls_ = cast(Hashable, cls) # Make mypy happy |
| 504 | + |
| 505 | + if ( |
| 506 | + _issubclass_fast(cls_, "numpy", "ndarray") |
| 507 | + or _issubclass_fast(cls_, "numpy", "generic") |
| 508 | + ): |
| 509 | + if use_compat is True: |
| 510 | + _check_api_version(api_version) |
| 511 | + from .. import numpy as xp |
| 512 | + elif use_compat is False: |
| 513 | + import numpy as xp # type: ignore[no-redef] |
| 514 | + else: |
| 515 | + # NumPy 2.0+ have __array_namespace__; however they are not |
| 516 | + # yet fully array API compatible. |
| 517 | + from .. import numpy as xp # type: ignore[no-redef] |
| 518 | + return xp, _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT |
| 519 | + |
| 520 | + # Note: this must happen _after_ the test for np.generic, |
| 521 | + # because np.float64 and np.complex128 are subclasses of float and complex. |
| 522 | + if issubclass(cls, int | float | complex | type(None)): |
| 523 | + return None, _ClsToXPInfo.SCALAR |
| 524 | + |
| 525 | + if _issubclass_fast(cls_, "cupy", "ndarray"): |
| 526 | + if _use_compat: |
| 527 | + _check_api_version(api_version) |
| 528 | + from .. import cupy as xp # type: ignore[no-redef] |
| 529 | + else: |
| 530 | + import cupy as xp # type: ignore[no-redef] |
| 531 | + return xp, None |
| 532 | + |
| 533 | + if _issubclass_fast(cls_, "torch", "Tensor"): |
| 534 | + if _use_compat: |
| 535 | + _check_api_version(api_version) |
| 536 | + from .. import torch as xp # type: ignore[no-redef] |
| 537 | + else: |
| 538 | + import torch as xp # type: ignore[no-redef] |
| 539 | + return xp, None |
| 540 | + |
| 541 | + if _issubclass_fast(cls_, "dask.array", "Array"): |
| 542 | + if _use_compat: |
| 543 | + _check_api_version(api_version) |
| 544 | + from ..dask import array as xp # type: ignore[no-redef] |
| 545 | + else: |
| 546 | + import dask.array as xp # type: ignore[no-redef] |
| 547 | + return xp, None |
| 548 | + |
| 549 | + # Backwards compatibility for jax<0.4.32 |
| 550 | + if _issubclass_fast(cls_, "jax", "Array"): |
| 551 | + return _jax_namespace(api_version, use_compat), None |
| 552 | + |
| 553 | + return None, None |
| 554 | + |
| 555 | + |
| 556 | +def _jax_namespace(api_version: str | None, use_compat: bool | None) -> Namespace: |
| 557 | + if use_compat: |
| 558 | + raise ValueError("JAX does not have an array-api-compat wrapper") |
| 559 | + import jax.numpy as jnp |
| 560 | + if not hasattr(jnp, "__array_namespace_info__"): |
| 561 | + # JAX v0.4.32 and newer implements the array API directly in jax.numpy. |
| 562 | + # For older JAX versions, it is available via jax.experimental.array_api. |
| 563 | + # jnp.Array objects gain the __array_namespace__ method. |
| 564 | + import jax.experimental.array_api # noqa: F401 |
| 565 | + # Test api_version |
| 566 | + return jnp.empty(0).__array_namespace__(api_version=api_version) |
| 567 | + |
| 568 | + |
488 | 569 | def array_namespace(
|
489 | 570 | *xs: Array | complex | None,
|
490 | 571 | api_version: str | None = None,
|
@@ -553,105 +634,40 @@ def your_function(x, y):
|
553 | 634 | is_pydata_sparse_array
|
554 | 635 |
|
555 | 636 | """
|
556 |
| - if use_compat not in [None, True, False]: |
557 |
| - raise ValueError("use_compat must be None, True, or False") |
558 |
| - |
559 |
| - _use_compat = use_compat in [None, True] |
560 |
| - |
561 | 637 | namespaces: set[Namespace] = set()
|
562 | 638 | for x in xs:
|
563 |
| - if is_numpy_array(x): |
564 |
| - import numpy as np |
565 |
| - |
566 |
| - from .. import numpy as numpy_namespace |
567 |
| - |
568 |
| - if use_compat is True: |
569 |
| - _check_api_version(api_version) |
570 |
| - namespaces.add(numpy_namespace) |
571 |
| - elif use_compat is False: |
572 |
| - namespaces.add(np) |
573 |
| - else: |
574 |
| - # numpy 2.0+ have __array_namespace__, however, they are not yet fully array API |
575 |
| - # compatible. |
576 |
| - namespaces.add(numpy_namespace) |
577 |
| - elif is_cupy_array(x): |
578 |
| - if _use_compat: |
579 |
| - _check_api_version(api_version) |
580 |
| - from .. import cupy as cupy_namespace |
581 |
| - |
582 |
| - namespaces.add(cupy_namespace) |
583 |
| - else: |
584 |
| - import cupy as cp # pyright: ignore[reportMissingTypeStubs] |
585 |
| - |
586 |
| - namespaces.add(cp) |
587 |
| - elif is_torch_array(x): |
588 |
| - if _use_compat: |
589 |
| - _check_api_version(api_version) |
590 |
| - from .. import torch as torch_namespace |
591 |
| - |
592 |
| - namespaces.add(torch_namespace) |
593 |
| - else: |
594 |
| - import torch |
595 |
| - |
596 |
| - namespaces.add(torch) |
597 |
| - elif is_dask_array(x): |
598 |
| - if _use_compat: |
599 |
| - _check_api_version(api_version) |
600 |
| - from ..dask import array as dask_namespace |
601 |
| - |
602 |
| - namespaces.add(dask_namespace) |
603 |
| - else: |
604 |
| - import dask.array as da |
605 |
| - |
606 |
| - namespaces.add(da) |
607 |
| - elif is_jax_array(x): |
608 |
| - if use_compat is True: |
609 |
| - _check_api_version(api_version) |
610 |
| - raise ValueError("JAX does not have an array-api-compat wrapper") |
611 |
| - elif use_compat is False: |
612 |
| - import jax.numpy as jnp |
613 |
| - else: |
614 |
| - # JAX v0.4.32 and newer implements the array API directly in jax.numpy. |
615 |
| - # For older JAX versions, it is available via jax.experimental.array_api. |
616 |
| - import jax.numpy |
617 |
| - |
618 |
| - if hasattr(jax.numpy, "__array_api_version__"): |
619 |
| - jnp = jax.numpy |
620 |
| - else: |
621 |
| - import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports] |
622 |
| - namespaces.add(jnp) |
623 |
| - elif is_pydata_sparse_array(x): |
624 |
| - if use_compat is True: |
625 |
| - _check_api_version(api_version) |
626 |
| - raise ValueError("`sparse` does not have an array-api-compat wrapper") |
627 |
| - else: |
628 |
| - import sparse # pyright: ignore[reportMissingTypeStubs] |
629 |
| - # `sparse` is already an array namespace. We do not have a wrapper |
630 |
| - # submodule for it. |
631 |
| - namespaces.add(sparse) |
632 |
| - elif hasattr(x, "__array_namespace__"): |
633 |
| - if use_compat is True: |
| 639 | + xp, info = _cls_to_namespace(cast(Hashable, type(x)), api_version, use_compat) |
| 640 | + if info is _ClsToXPInfo.SCALAR: |
| 641 | + continue |
| 642 | + |
| 643 | + if ( |
| 644 | + info is _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT |
| 645 | + and _is_jax_zero_gradient_array(x) |
| 646 | + ): |
| 647 | + xp = _jax_namespace(api_version, use_compat) |
| 648 | + |
| 649 | + if xp is None: |
| 650 | + get_ns = getattr(x, "__array_namespace__", None) |
| 651 | + if get_ns is None: |
| 652 | + raise TypeError(f"{type(x).__name__} is not a supported array type") |
| 653 | + if use_compat: |
634 | 654 | raise ValueError(
|
635 | 655 | "The given array does not have an array-api-compat wrapper"
|
636 | 656 | )
|
637 |
| - x = cast("SupportsArrayNamespace[Any]", x) |
638 |
| - namespaces.add(x.__array_namespace__(api_version=api_version)) |
639 |
| - elif isinstance(x, (bool, int, float, complex, type(None))): |
640 |
| - continue |
641 |
| - else: |
642 |
| - # TODO: Support Python scalars? |
643 |
| - raise TypeError(f"{type(x).__name__} is not a supported array type") |
| 657 | + xp = get_ns(api_version=api_version) |
644 | 658 |
|
645 |
| - if not namespaces: |
646 |
| - raise TypeError("Unrecognized array input") |
| 659 | + namespaces.add(xp) |
647 | 660 |
|
648 |
| - if len(namespaces) != 1: |
| 661 | + try: |
| 662 | + (xp,) = namespaces |
| 663 | + return xp |
| 664 | + except ValueError: |
| 665 | + if not namespaces: |
| 666 | + raise TypeError( |
| 667 | + "array_namespace requires at least one non-scalar array input" |
| 668 | + ) |
649 | 669 | raise TypeError(f"Multiple namespaces for array inputs: {namespaces}")
|
650 | 670 |
|
651 |
| - (xp,) = namespaces |
652 |
| - |
653 |
| - return xp |
654 |
| - |
655 | 671 |
|
656 | 672 | # backwards compatibility alias
|
657 | 673 | get_namespace = array_namespace
|
|
0 commit comments