Skip to content

Commit 362c48a

Browse files
committed
Type annotations, part 4
1 parent 205c967 commit 362c48a

22 files changed

+269
-396
lines changed

array_api_compat/_internal.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ def wrapped_f(*args: object, **kwargs: object) -> object:
4646
specification for more details.
4747
4848
"""
49-
wrapped_f.__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue]
50-
return wrapped_f # pyright: ignore[reportReturnType]
49+
wrapped_f.__signature__ = new_sig # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
50+
return wrapped_f # type: ignore[return-value] # pyright: ignore[reportReturnType]
5151

5252
return inner
5353

array_api_compat/common/_aliases.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
from __future__ import annotations
66

77
import inspect
8-
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast
8+
from collections.abc import Sequence
9+
from types import NoneType
10+
from typing import TYPE_CHECKING, Any, NamedTuple, cast
911

1012
from ._helpers import _check_device, array_namespace
1113
from ._helpers import device as _get_device
12-
from ._helpers import is_cupy_namespace as _is_cupy_namespace
14+
from ._helpers import is_cupy_namespace
1315
from ._typing import Array, Device, DType, Namespace
1416

1517
if TYPE_CHECKING:
@@ -381,8 +383,8 @@ def clip(
381383
# TODO: np.clip has other ufunc kwargs
382384
out: Array | None = None,
383385
) -> Array:
384-
def _isscalar(a: object) -> TypeIs[int | float | None]:
385-
return isinstance(a, (int, float, type(None)))
386+
def _isscalar(a: object) -> TypeIs[float | None]:
387+
return isinstance(a, int | float | NoneType)
386388

387389
min_shape = () if _isscalar(min) else min.shape
388390
max_shape = () if _isscalar(max) else max.shape
@@ -450,7 +452,7 @@ def reshape(
450452
shape: tuple[int, ...],
451453
xp: Namespace,
452454
*,
453-
copy: Optional[bool] = None,
455+
copy: bool | None = None,
454456
**kwargs: object,
455457
) -> Array:
456458
if copy is True:
@@ -657,7 +659,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
657659
out = xp.sign(x, **kwargs)
658660
# CuPy sign() does not propagate nans. See
659661
# https://github.com/data-apis/array-api-compat/issues/136
660-
if _is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp):
662+
if is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp):
661663
out[xp.isnan(x)] = xp.nan
662664
return out[()]
663665

@@ -720,7 +722,7 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
720722
"finfo",
721723
"iinfo",
722724
]
723-
_all_ignore = ["inspect", "array_namespace", "NamedTuple"]
725+
_all_ignore = ["is_cupy_namespace", "inspect", "array_namespace", "NamedTuple"]
724726

725727

726728
def __dir__() -> list[str]:

array_api_compat/common/_helpers.py

Lines changed: 36 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,56 +12,51 @@
1212
import math
1313
import sys
1414
import warnings
15-
from collections.abc import Collection
15+
from types import NoneType
1616
from typing import (
1717
TYPE_CHECKING,
1818
Any,
1919
Final,
2020
Literal,
21-
SupportsIndex,
2221
TypeAlias,
2322
TypeGuard,
24-
TypeVar,
2523
cast,
2624
overload,
2725
)
2826

2927
from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace
3028

3129
if TYPE_CHECKING:
32-
30+
import cupy as cp
3331
import dask.array as da
3432
import jax
3533
import ndonnx as ndx
3634
import numpy as np
3735
import numpy.typing as npt
38-
import sparse # pyright: ignore[reportMissingTypeStubs]
36+
import sparse
3937
import torch
4038

4139
# TODO: import from typing (requires Python >=3.13)
42-
from typing_extensions import TypeIs, TypeVar
43-
44-
_SizeT = TypeVar("_SizeT", bound = int | None)
40+
from typing_extensions import TypeIs
4541

4642
_ZeroGradientArray: TypeAlias = npt.NDArray[np.void]
47-
_CupyArray: TypeAlias = Any # cupy has no py.typed
4843

4944
_ArrayApiObj: TypeAlias = (
5045
npt.NDArray[Any]
46+
| cp.ndarray
5147
| da.Array
5248
| jax.Array
5349
| ndx.Array
5450
| sparse.SparseArray
5551
| torch.Tensor
56-
| SupportsArrayNamespace[Any]
57-
| _CupyArray
52+
| SupportsArrayNamespace
5853
)
5954

6055
_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"})
6156
_API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"})
6257

6358

64-
def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
59+
def _is_jax_zero_gradient_array(x: object) -> TypeIs[_ZeroGradientArray]:
6560
"""Return True if `x` is a zero-gradient array.
6661
6762
These arrays are a design quirk of Jax that may one day be removed.
@@ -80,7 +75,7 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
8075
)
8176

8277

83-
def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]:
78+
def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]:
8479
"""
8580
Return True if `x` is a NumPy array.
8681
@@ -137,7 +132,7 @@ def is_cupy_array(x: object) -> bool:
137132
if "cupy" not in sys.modules:
138133
return False
139134

140-
import cupy as cp # pyright: ignore[reportMissingTypeStubs]
135+
import cupy as cp
141136

142137
# TODO: Should we reject ndarray subclasses?
143138
return isinstance(x, cp.ndarray) # pyright: ignore[reportUnknownMemberType]
@@ -280,13 +275,13 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
280275
if "sparse" not in sys.modules:
281276
return False
282277

283-
import sparse # pyright: ignore[reportMissingTypeStubs]
278+
import sparse
284279

285280
# TODO: Account for other backends.
286281
return isinstance(x, sparse.SparseArray)
287282

288283

289-
def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType]
284+
def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]:
290285
"""
291286
Return True if `x` is an array API compatible array object.
292287
@@ -587,7 +582,7 @@ def your_function(x, y):
587582

588583
namespaces.add(cupy_namespace)
589584
else:
590-
import cupy as cp # pyright: ignore[reportMissingTypeStubs]
585+
import cupy as cp
591586

592587
namespaces.add(cp)
593588
elif is_torch_array(x):
@@ -624,14 +619,14 @@ def your_function(x, y):
624619
if hasattr(jax.numpy, "__array_api_version__"):
625620
jnp = jax.numpy
626621
else:
627-
import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports]
622+
import jax.experimental.array_api as jnp # type: ignore[no-redef]
628623
namespaces.add(jnp)
629624
elif is_pydata_sparse_array(x):
630625
if use_compat is True:
631626
_check_api_version(api_version)
632627
raise ValueError("`sparse` does not have an array-api-compat wrapper")
633628
else:
634-
import sparse # pyright: ignore[reportMissingTypeStubs]
629+
import sparse
635630
# `sparse` is already an array namespace. We do not have a wrapper
636631
# submodule for it.
637632
namespaces.add(sparse)
@@ -640,9 +635,9 @@ def your_function(x, y):
640635
raise ValueError(
641636
"The given array does not have an array-api-compat wrapper"
642637
)
643-
x = cast("SupportsArrayNamespace[Any]", x)
638+
x = cast(SupportsArrayNamespace, x)
644639
namespaces.add(x.__array_namespace__(api_version=api_version))
645-
elif isinstance(x, (bool, int, float, complex, type(None))):
640+
elif isinstance(x, int | float | complex | NoneType):
646641
continue
647642
else:
648643
# TODO: Support Python scalars?
@@ -738,7 +733,7 @@ def device(x: _ArrayApiObj, /) -> Device:
738733
return "cpu"
739734
elif is_dask_array(x):
740735
# Peek at the metadata of the Dask array to determine type
741-
if is_numpy_array(x._meta): # pyright: ignore
736+
if is_numpy_array(x._meta):
742737
# Must be on CPU since backed by numpy
743738
return "cpu"
744739
return _DASK_DEVICE
@@ -767,7 +762,7 @@ def device(x: _ArrayApiObj, /) -> Device:
767762
return "cpu"
768763
# Return the device of the constituent array
769764
return device(inner) # pyright: ignore
770-
return x.device # pyright: ignore
765+
return x.device # type: ignore # pyright: ignore
771766

772767

773768
# Prevent shadowing, used below
@@ -776,12 +771,12 @@ def device(x: _ArrayApiObj, /) -> Device:
776771

777772
# Based on cupy.array_api.Array.to_device
778773
def _cupy_to_device(
779-
x: _CupyArray,
774+
x: cp.ndarray,
780775
device: Device,
781776
/,
782777
stream: int | Any | None = None,
783-
) -> _CupyArray:
784-
import cupy as cp # pyright: ignore[reportMissingTypeStubs]
778+
) -> cp.ndarray:
779+
import cupy as cp
785780
from cupy.cuda import Device as _Device # pyright: ignore
786781
from cupy.cuda import stream as stream_module # pyright: ignore
787782
from cupy_backends.cuda.api import runtime # pyright: ignore
@@ -797,10 +792,10 @@ def _cupy_to_device(
797792
raise ValueError(f"Unsupported device {device!r}")
798793
else:
799794
# see cupy/cupy#5985 for the reason how we handle device/stream here
800-
prev_device: Any = runtime.getDevice() # pyright: ignore[reportUnknownMemberType]
795+
prev_device: Device = runtime.getDevice() # pyright: ignore[reportUnknownMemberType]
801796
prev_stream = None
802797
if stream is not None:
803-
prev_stream: Any = stream_module.get_current_stream() # pyright: ignore
798+
prev_stream = stream_module.get_current_stream() # pyright: ignore
804799
# stream can be an int as specified in __dlpack__, or a CuPy stream
805800
if isinstance(stream, int):
806801
stream = cp.cuda.ExternalStream(stream) # pyright: ignore
@@ -814,7 +809,7 @@ def _cupy_to_device(
814809
arr = x.copy()
815810
finally:
816811
runtime.setDevice(prev_device) # pyright: ignore[reportUnknownMemberType]
817-
if stream is not None:
812+
if prev_stream is not None:
818813
prev_stream.use()
819814
return arr
820815

@@ -823,7 +818,7 @@ def _torch_to_device(
823818
x: torch.Tensor,
824819
device: torch.device | str | int,
825820
/,
826-
stream: None = None,
821+
stream: int | Any | None = None,
827822
) -> torch.Tensor:
828823
if stream is not None:
829824
raise NotImplementedError
@@ -889,7 +884,7 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
889884
# cupy does not yet have to_device
890885
return _cupy_to_device(x, device, stream=stream)
891886
elif is_torch_array(x):
892-
return _torch_to_device(x, device, stream=stream) # pyright: ignore[reportArgumentType]
887+
return _torch_to_device(x, device, stream=stream)
893888
elif is_dask_array(x):
894889
if stream is not None:
895890
raise ValueError("The stream argument to to_device() is not supported")
@@ -914,12 +909,12 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
914909

915910

916911
@overload
917-
def size(x: HasShape[Collection[SupportsIndex]]) -> int: ...
912+
def size(x: HasShape[int]) -> int: ...
918913
@overload
919-
def size(x: HasShape[Collection[None]]) -> None: ...
914+
def size(x: HasShape[int | None]) -> int | None: ...
920915
@overload
921-
def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ...
922-
def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
916+
def size(x: HasShape[float]) -> int | None: ... # Dask special case
917+
def size(x: HasShape[float | None]) -> int | None:
923918
"""
924919
Return the total number of elements of x.
925920
@@ -934,12 +929,12 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
934929
# Lazy API compliant arrays, such as ndonnx, can contain None in their shape
935930
if None in x.shape:
936931
return None
937-
out = math.prod(cast("Collection[SupportsIndex]", x.shape))
932+
out = math.prod(cast(tuple[float, ...], x.shape))
938933
# dask.array.Array.shape can contain NaN
939-
return None if math.isnan(out) else out
934+
return None if math.isnan(out) else cast(int, out)
940935

941936

942-
def is_writeable_array(x: object) -> bool:
937+
def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]:
943938
"""
944939
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
945940
Return False if `x` is not an array API compatible object.
@@ -956,7 +951,7 @@ def is_writeable_array(x: object) -> bool:
956951
return is_array_api_obj(x)
957952

958953

959-
def is_lazy_array(x: object) -> bool:
954+
def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
960955
"""Return True if x is potentially a future or it may be otherwise impossible or
961956
expensive to eagerly read its contents, regardless of their size, e.g. by
962957
calling ``bool(x)`` or ``float(x)``.
@@ -997,7 +992,7 @@ def is_lazy_array(x: object) -> bool:
997992
# on __bool__ (dask is one such example, which however is special-cased above).
998993

999994
# Select a single point of the array
1000-
s = size(cast("HasShape[Collection[SupportsIndex | None]]", x))
995+
s = size(cast(HasShape, x))
1001996
if s is None:
1002997
return True
1003998
xp = array_namespace(x)
@@ -1044,5 +1039,6 @@ def is_lazy_array(x: object) -> bool:
10441039

10451040
_all_ignore = ["sys", "math", "inspect", "warnings"]
10461041

1042+
10471043
def __dir__() -> list[str]:
10481044
return __all__

array_api_compat/common/_linalg.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
if np.__version__[0] == "2":
99
from numpy.lib.array_utils import normalize_axis_tuple
1010
else:
11-
from numpy.core.numeric import normalize_axis_tuple
11+
from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef]
1212

1313
from .._internal import get_xp
1414
from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot
@@ -164,7 +164,7 @@ def vector_norm(
164164
if axis is None:
165165
# Note: xp.linalg.norm() doesn't handle 0-D arrays
166166
_x = x.ravel()
167-
_axis = 0
167+
axis = 0
168168
elif isinstance(axis, tuple):
169169
# Note: The axis argument supports any number of axes, whereas
170170
# xp.linalg.norm() only supports a single axis for vector norm.
@@ -176,25 +176,24 @@ def vector_norm(
176176
newshape = axis + rest
177177
_x = xp.transpose(x, newshape).reshape(
178178
(math.prod([x.shape[i] for i in axis]), *[x.shape[i] for i in rest]))
179-
_axis = 0
179+
axis = 0
180180
else:
181181
_x = x
182-
_axis = axis
183182

184-
res = xp.linalg.norm(_x, axis=_axis, ord=ord)
183+
res = xp.linalg.norm(_x, axis=axis, ord=ord)
185184

186185
if keepdims:
187186
# We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
188187
# above to avoid matrix norm logic.
189188
shape = list(x.shape)
190-
_axis = cast(
189+
axis = cast(
191190
"tuple[int, ...]",
192191
normalize_axis_tuple( # pyright: ignore[reportCallIssue]
193192
range(x.ndim) if axis is None else axis,
194193
x.ndim,
195194
),
196195
)
197-
for i in _axis:
196+
for i in axis:
198197
shape[i] = 1
199198
res = xp.reshape(res, tuple(shape))
200199

0 commit comments

Comments
 (0)