12
12
import math
13
13
import sys
14
14
import warnings
15
- from collections . abc import Collection
15
+ from types import NoneType
16
16
from typing import (
17
17
TYPE_CHECKING ,
18
18
Any ,
19
19
Final ,
20
20
Literal ,
21
- SupportsIndex ,
22
21
TypeAlias ,
23
22
TypeGuard ,
24
- TypeVar ,
25
23
cast ,
26
24
overload ,
27
25
)
28
26
29
27
from ._typing import Array , Device , HasShape , Namespace , SupportsArrayNamespace
30
28
31
29
if TYPE_CHECKING :
32
-
30
+ import cupy as cp
33
31
import dask .array as da
34
32
import jax
35
33
import ndonnx as ndx
36
34
import numpy as np
37
35
import numpy .typing as npt
38
- import sparse # pyright: ignore[reportMissingTypeStubs]
36
+ import sparse
39
37
import torch
40
38
41
39
# TODO: import from typing (requires Python >=3.13)
42
- from typing_extensions import TypeIs , TypeVar
43
-
44
- _SizeT = TypeVar ("_SizeT" , bound = int | None )
40
+ from typing_extensions import TypeIs
45
41
46
42
_ZeroGradientArray : TypeAlias = npt .NDArray [np .void ]
47
- _CupyArray : TypeAlias = Any # cupy has no py.typed
48
43
49
44
_ArrayApiObj : TypeAlias = (
50
45
npt .NDArray [Any ]
46
+ | cp .ndarray
51
47
| da .Array
52
48
| jax .Array
53
49
| ndx .Array
54
50
| sparse .SparseArray
55
51
| torch .Tensor
56
- | SupportsArrayNamespace [Any ]
57
- | _CupyArray
52
+ | SupportsArrayNamespace
58
53
)
59
54
60
55
_API_VERSIONS_OLD : Final = frozenset ({"2021.12" , "2022.12" , "2023.12" })
61
56
_API_VERSIONS : Final = _API_VERSIONS_OLD | frozenset ({"2024.12" })
62
57
63
58
64
- def _is_jax_zero_gradient_array (x : object ) -> TypeGuard [_ZeroGradientArray ]:
59
+ def _is_jax_zero_gradient_array (x : object ) -> TypeIs [_ZeroGradientArray ]:
65
60
"""Return True if `x` is a zero-gradient array.
66
61
67
62
These arrays are a design quirk of Jax that may one day be removed.
@@ -80,7 +75,7 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
80
75
)
81
76
82
77
83
- def is_numpy_array (x : object ) -> TypeGuard [npt .NDArray [Any ]]:
78
+ def is_numpy_array (x : object ) -> TypeIs [npt .NDArray [Any ]]:
84
79
"""
85
80
Return True if `x` is a NumPy array.
86
81
@@ -137,7 +132,7 @@ def is_cupy_array(x: object) -> bool:
137
132
if "cupy" not in sys .modules :
138
133
return False
139
134
140
- import cupy as cp # pyright: ignore[reportMissingTypeStubs]
135
+ import cupy as cp
141
136
142
137
# TODO: Should we reject ndarray subclasses?
143
138
return isinstance (x , cp .ndarray ) # pyright: ignore[reportUnknownMemberType]
@@ -280,13 +275,13 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
280
275
if "sparse" not in sys .modules :
281
276
return False
282
277
283
- import sparse # pyright: ignore[reportMissingTypeStubs]
278
+ import sparse
284
279
285
280
# TODO: Account for other backends.
286
281
return isinstance (x , sparse .SparseArray )
287
282
288
283
289
- def is_array_api_obj (x : object ) -> TypeIs [_ArrayApiObj ]: # pyright: ignore[reportUnknownParameterType]
284
+ def is_array_api_obj (x : object ) -> TypeGuard [_ArrayApiObj ]:
290
285
"""
291
286
Return True if `x` is an array API compatible array object.
292
287
@@ -587,7 +582,7 @@ def your_function(x, y):
587
582
588
583
namespaces .add (cupy_namespace )
589
584
else :
590
- import cupy as cp # pyright: ignore[reportMissingTypeStubs]
585
+ import cupy as cp
591
586
592
587
namespaces .add (cp )
593
588
elif is_torch_array (x ):
@@ -624,14 +619,14 @@ def your_function(x, y):
624
619
if hasattr (jax .numpy , "__array_api_version__" ):
625
620
jnp = jax .numpy
626
621
else :
627
- import jax .experimental .array_api as jnp # pyright : ignore[reportMissingImports ]
622
+ import jax .experimental .array_api as jnp # type : ignore[no-redef ]
628
623
namespaces .add (jnp )
629
624
elif is_pydata_sparse_array (x ):
630
625
if use_compat is True :
631
626
_check_api_version (api_version )
632
627
raise ValueError ("`sparse` does not have an array-api-compat wrapper" )
633
628
else :
634
- import sparse # pyright: ignore[reportMissingTypeStubs]
629
+ import sparse
635
630
# `sparse` is already an array namespace. We do not have a wrapper
636
631
# submodule for it.
637
632
namespaces .add (sparse )
@@ -640,9 +635,9 @@ def your_function(x, y):
640
635
raise ValueError (
641
636
"The given array does not have an array-api-compat wrapper"
642
637
)
643
- x = cast (" SupportsArrayNamespace[Any]" , x )
638
+ x = cast (SupportsArrayNamespace , x )
644
639
namespaces .add (x .__array_namespace__ (api_version = api_version ))
645
- elif isinstance (x , ( bool , int , float , complex , type ( None )) ):
640
+ elif isinstance (x , int | float | complex | NoneType ):
646
641
continue
647
642
else :
648
643
# TODO: Support Python scalars?
@@ -738,7 +733,7 @@ def device(x: _ArrayApiObj, /) -> Device:
738
733
return "cpu"
739
734
elif is_dask_array (x ):
740
735
# Peek at the metadata of the Dask array to determine type
741
- if is_numpy_array (x ._meta ): # pyright: ignore
736
+ if is_numpy_array (x ._meta ):
742
737
# Must be on CPU since backed by numpy
743
738
return "cpu"
744
739
return _DASK_DEVICE
@@ -767,7 +762,7 @@ def device(x: _ArrayApiObj, /) -> Device:
767
762
return "cpu"
768
763
# Return the device of the constituent array
769
764
return device (inner ) # pyright: ignore
770
- return x .device # pyright: ignore
765
+ return x .device # type: ignore # pyright: ignore
771
766
772
767
773
768
# Prevent shadowing, used below
@@ -776,12 +771,12 @@ def device(x: _ArrayApiObj, /) -> Device:
776
771
777
772
# Based on cupy.array_api.Array.to_device
778
773
def _cupy_to_device (
779
- x : _CupyArray ,
774
+ x : cp . ndarray ,
780
775
device : Device ,
781
776
/ ,
782
777
stream : int | Any | None = None ,
783
- ) -> _CupyArray :
784
- import cupy as cp # pyright: ignore[reportMissingTypeStubs]
778
+ ) -> cp . ndarray :
779
+ import cupy as cp
785
780
from cupy .cuda import Device as _Device # pyright: ignore
786
781
from cupy .cuda import stream as stream_module # pyright: ignore
787
782
from cupy_backends .cuda .api import runtime # pyright: ignore
@@ -797,10 +792,10 @@ def _cupy_to_device(
797
792
raise ValueError (f"Unsupported device { device !r} " )
798
793
else :
799
794
# see cupy/cupy#5985 for the reason how we handle device/stream here
800
- prev_device : Any = runtime .getDevice () # pyright: ignore[reportUnknownMemberType]
795
+ prev_device : Device = runtime .getDevice () # pyright: ignore[reportUnknownMemberType]
801
796
prev_stream = None
802
797
if stream is not None :
803
- prev_stream : Any = stream_module .get_current_stream () # pyright: ignore
798
+ prev_stream = stream_module .get_current_stream () # pyright: ignore
804
799
# stream can be an int as specified in __dlpack__, or a CuPy stream
805
800
if isinstance (stream , int ):
806
801
stream = cp .cuda .ExternalStream (stream ) # pyright: ignore
@@ -814,7 +809,7 @@ def _cupy_to_device(
814
809
arr = x .copy ()
815
810
finally :
816
811
runtime .setDevice (prev_device ) # pyright: ignore[reportUnknownMemberType]
817
- if stream is not None :
812
+ if prev_stream is not None :
818
813
prev_stream .use ()
819
814
return arr
820
815
@@ -823,7 +818,7 @@ def _torch_to_device(
823
818
x : torch .Tensor ,
824
819
device : torch .device | str | int ,
825
820
/ ,
826
- stream : None = None ,
821
+ stream : int | Any | None = None ,
827
822
) -> torch .Tensor :
828
823
if stream is not None :
829
824
raise NotImplementedError
@@ -889,7 +884,7 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
889
884
# cupy does not yet have to_device
890
885
return _cupy_to_device (x , device , stream = stream )
891
886
elif is_torch_array (x ):
892
- return _torch_to_device (x , device , stream = stream ) # pyright: ignore[reportArgumentType]
887
+ return _torch_to_device (x , device , stream = stream )
893
888
elif is_dask_array (x ):
894
889
if stream is not None :
895
890
raise ValueError ("The stream argument to to_device() is not supported" )
@@ -914,12 +909,12 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
914
909
915
910
916
911
@overload
917
- def size (x : HasShape [Collection [ SupportsIndex ] ]) -> int : ...
912
+ def size (x : HasShape [int ]) -> int : ...
918
913
@overload
919
- def size (x : HasShape [Collection [ None ]] ) -> None : ...
914
+ def size (x : HasShape [int | None ]) -> int | None : ...
920
915
@overload
921
- def size (x : HasShape [Collection [ SupportsIndex | None ]] ) -> int | None : ...
922
- def size (x : HasShape [Collection [ SupportsIndex | None ] ]) -> int | None :
916
+ def size (x : HasShape [float ] ) -> int | None : ... # Dask special case
917
+ def size (x : HasShape [float | None ]) -> int | None :
923
918
"""
924
919
Return the total number of elements of x.
925
920
@@ -934,12 +929,12 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
934
929
# Lazy API compliant arrays, such as ndonnx, can contain None in their shape
935
930
if None in x .shape :
936
931
return None
937
- out = math .prod (cast ("Collection[SupportsIndex]" , x .shape ))
932
+ out = math .prod (cast (tuple [ float , ...] , x .shape ))
938
933
# dask.array.Array.shape can contain NaN
939
- return None if math .isnan (out ) else out
934
+ return None if math .isnan (out ) else cast ( int , out )
940
935
941
936
942
- def is_writeable_array (x : object ) -> bool :
937
+ def is_writeable_array (x : object ) -> TypeGuard [ _ArrayApiObj ] :
943
938
"""
944
939
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
945
940
Return False if `x` is not an array API compatible object.
@@ -956,7 +951,7 @@ def is_writeable_array(x: object) -> bool:
956
951
return is_array_api_obj (x )
957
952
958
953
959
- def is_lazy_array (x : object ) -> bool :
954
+ def is_lazy_array (x : object ) -> TypeGuard [ _ArrayApiObj ] :
960
955
"""Return True if x is potentially a future or it may be otherwise impossible or
961
956
expensive to eagerly read its contents, regardless of their size, e.g. by
962
957
calling ``bool(x)`` or ``float(x)``.
@@ -997,7 +992,7 @@ def is_lazy_array(x: object) -> bool:
997
992
# on __bool__ (dask is one such example, which however is special-cased above).
998
993
999
994
# Select a single point of the array
1000
- s = size (cast (" HasShape[Collection[SupportsIndex | None]]" , x ))
995
+ s = size (cast (HasShape , x ))
1001
996
if s is None :
1002
997
return True
1003
998
xp = array_namespace (x )
@@ -1044,5 +1039,6 @@ def is_lazy_array(x: object) -> bool:
1044
1039
1045
1040
_all_ignore = ["sys" , "math" , "inspect" , "warnings" ]
1046
1041
1042
+
1047
1043
def __dir__ () -> list [str ]:
1048
1044
return __all__
0 commit comments