From e1d8db4a6b78ce0fd50e400f299882891a8ecb6e Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 17 Mar 2025 12:17:30 +0000 Subject: [PATCH 1/3] ENH: Fully annotate Array --- pixi.lock | 2 +- pyproject.toml | 2 +- src/array_api_extra/_lib/_at.py | 59 ++++++----- src/array_api_extra/_lib/_funcs.py | 59 +++++------ src/array_api_extra/_lib/_testing.py | 28 ++++-- src/array_api_extra/_lib/_utils/_compat.pyi | 16 +-- src/array_api_extra/_lib/_utils/_helpers.py | 43 ++++++-- src/array_api_extra/_lib/_utils/_typing.py | 27 +++-- src/array_api_extra/_lib/_utils/_typing.pyi | 105 ++++++++++++++++++++ src/array_api_extra/testing.py | 2 +- tests/conftest.py | 2 +- tests/test_at.py | 47 ++++----- tests/test_funcs.py | 30 +++--- tests/test_helpers.py | 26 ++++- tests/test_testing.py | 51 +++++----- 15 files changed, 335 insertions(+), 164 deletions(-) create mode 100644 src/array_api_extra/_lib/_utils/_typing.pyi diff --git a/pixi.lock b/pixi.lock index 7310e5af..c2071402 100644 --- a/pixi.lock +++ b/pixi.lock @@ -3788,7 +3788,7 @@ packages: - pypi: . name: array-api-extra version: 0.7.0.dev0 - sha256: 88f998278ea7742857d385d2171ce91fe8ffde2d36416810070e15d523f5d0bf + sha256: af349b53edfb4298b00cbb25c5e3d68fa41ae6abcca3d0a7032f4423fe8bcd14 requires_dist: - array-api-compat>=1.11,<2 requires_python: '>=3.10' diff --git a/pyproject.toml b/pyproject.toml index 63682d9d..78ab718f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -203,7 +203,7 @@ enable_error_code = ["ignore-without-code", "truthy-bool"] # https://github.com/data-apis/array-api-typing disallow_any_expr = false # false positives with input validation -disable_error_code = ["redundant-expr", "unreachable"] +disable_error_code = ["redundant-expr", "unreachable", "no-any-return"] [[tool.mypy.overrides]] # slow/unavailable on Windows; do not add to the lint env diff --git a/src/array_api_extra/_lib/_at.py b/src/array_api_extra/_lib/_at.py index a008efa9..11c8ff6c 100644 --- a/src/array_api_extra/_lib/_at.py +++ b/src/array_api_extra/_lib/_at.py @@ -15,7 +15,7 @@ is_jax_array, is_writeable_array, ) -from ._utils._typing import Array, Index +from ._utils._typing import Array, SetIndex class _AtOp(Enum): @@ -43,7 +43,13 @@ def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[ return self.value -_undef = object() +class Undef(Enum): + """Sentinel for undefined values.""" + + UNDEF = 0 + + +_undef = Undef.UNDEF class at: # pylint: disable=invalid-name # numpydoc ignore=PR02 @@ -188,16 +194,16 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02 """ _x: Array - _idx: Index + _idx: SetIndex | Undef __slots__: ClassVar[tuple[str, ...]] = ("_idx", "_x") def __init__( - self, x: Array, idx: Index = _undef, / + self, x: Array, idx: SetIndex | Undef = _undef, / ) -> None: # numpydoc ignore=GL08 self._x = x self._idx = idx - def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01 + def __getitem__(self, idx: SetIndex, /) -> at: # numpydoc ignore=PR01,RT01 """ Allow for the alternate syntax ``at(x)[start:stop:step]``. @@ -212,9 +218,9 @@ def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01 def _op( self, at_op: _AtOp, - in_place_op: Callable[[Array, Array | object], Array] | None, + in_place_op: Callable[[Array, Array | complex], Array] | None, out_of_place_op: Callable[[Array, Array], Array] | None, - y: Array | object, + y: Array | complex, /, copy: bool | None, xp: ModuleType | None, @@ -226,7 +232,7 @@ def _op( ---------- at_op : _AtOp Method of JAX's Array.at[]. - in_place_op : Callable[[Array, Array | object], Array] | None + in_place_op : Callable[[Array, Array | complex], Array] | None In-place operation to apply on mutable backends:: x[idx] = in_place_op(x[idx], y) @@ -245,7 +251,7 @@ def _op( x = xp.where(idx, y, x) - y : array or object + y : array or complex Right-hand side of the operation. copy : bool or None Whether to copy the input array. See the class docstring for details. @@ -260,7 +266,7 @@ def _op( x, idx = self._x, self._idx xp = array_namespace(x, y) if xp is None else xp - if idx is _undef: + if isinstance(idx, Undef): msg = ( "Index has not been set.\n" "Usage: either\n" @@ -306,7 +312,10 @@ def _op( if copy or (copy is None and not writeable): if is_jax_array(x): # Use JAX's at[] - func = cast(Callable[[Array], Array], getattr(x.at[idx], at_op.value)) + func = cast( + Callable[[Array | complex], Array], + getattr(x.at[idx], at_op.value), # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue,reportUnknownArgumentType] + ) out = func(y) # Undo int->float promotion on JAX after _AtOp.DIVIDE return xp.astype(out, x.dtype, copy=False) @@ -315,10 +324,10 @@ def _op( # with a copy followed by an update x = xp.asarray(x, copy=True) - if writeable is False: - # A copy of a read-only numpy array is writeable - # Note: this assumes that a copy of a writeable array is writeable - writeable = None + # A copy of a read-only numpy array is writeable + # Note: this assumes that a copy of a writeable array is writeable + assert not writeable + writeable = None if writeable is None: writeable = is_writeable_array(x) @@ -328,14 +337,14 @@ def _op( raise ValueError(msg) if in_place_op: # add(), subtract(), ... - x[self._idx] = in_place_op(x[self._idx], y) + x[idx] = in_place_op(x[idx], y) else: # set() - x[self._idx] = y + x[idx] = y return x def set( self, - y: Array | object, + y: Array | complex, /, copy: bool | None = None, xp: ModuleType | None = None, @@ -345,7 +354,7 @@ def set( def add( self, - y: Array | object, + y: Array | complex, /, copy: bool | None = None, xp: ModuleType | None = None, @@ -359,7 +368,7 @@ def add( def subtract( self, - y: Array | object, + y: Array | complex, /, copy: bool | None = None, xp: ModuleType | None = None, @@ -371,7 +380,7 @@ def subtract( def multiply( self, - y: Array | object, + y: Array | complex, /, copy: bool | None = None, xp: ModuleType | None = None, @@ -383,7 +392,7 @@ def multiply( def divide( self, - y: Array | object, + y: Array | complex, /, copy: bool | None = None, xp: ModuleType | None = None, @@ -395,7 +404,7 @@ def divide( def power( self, - y: Array | object, + y: Array | complex, /, copy: bool | None = None, xp: ModuleType | None = None, @@ -405,7 +414,7 @@ def power( def min( self, - y: Array | object, + y: Array | complex, /, copy: bool | None = None, xp: ModuleType | None = None, @@ -417,7 +426,7 @@ def min( def max( self, - y: Array | object, + y: Array | complex, /, copy: bool | None = None, xp: ModuleType | None = None, diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index cf06dd55..dbb6bc14 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -7,12 +7,12 @@ import warnings from collections.abc import Sequence from types import ModuleType -from typing import TYPE_CHECKING, cast +from typing import cast from ._at import at from ._utils import _compat, _helpers from ._utils._compat import array_namespace, is_jax_array -from ._utils._helpers import asarrays, ndindex +from ._utils._helpers import asarrays, eager_shape, ndindex from ._utils._typing import Array __all__ = [ @@ -211,11 +211,13 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: m = xp.astype(m, dtype) avg = _helpers.mean(m, axis=1, xp=xp) - fact = m.shape[1] - 1 + + m_shape = eager_shape(m) + fact = m_shape[1] - 1 if fact <= 0: warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2) - fact = 0.0 + fact = 0 m -= avg[:, None] m_transpose = m.T @@ -274,8 +276,10 @@ def create_diagonal( if x.ndim == 0: err_msg = "`x` must be at least 1-dimensional." raise ValueError(err_msg) - batch_dims = x.shape[:-1] - n = x.shape[-1] + abs(offset) + + x_shape = eager_shape(x) + batch_dims = x_shape[:-1] + n = x_shape[-1] + abs(offset) diag = xp.zeros((*batch_dims, n**2), dtype=x.dtype, device=_compat.device(x)) target_slice = slice( @@ -385,10 +389,6 @@ def isclose( ) -> Array: # numpydoc ignore=PR01,RT01 """See docstring in array_api_extra._delegation.""" a, b = asarrays(a, b, xp=xp) - # FIXME https://github.com/microsoft/pyright/issues/10085 - if TYPE_CHECKING: # pragma: nocover - assert _compat.is_array_api_obj(a) - assert _compat.is_array_api_obj(b) a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating")) b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating")) @@ -505,24 +505,17 @@ def kron( if xp is None: xp = array_namespace(a, b) a, b = asarrays(a, b, xp=xp) - # FIXME https://github.com/microsoft/pyright/issues/10085 - if TYPE_CHECKING: # pragma: nocover - assert _compat.is_array_api_obj(a) - assert _compat.is_array_api_obj(b) singletons = (1,) * (b.ndim - a.ndim) - a = xp.broadcast_to(a, singletons + a.shape) - # FIXME https://github.com/microsoft/pyright/issues/10085 - if TYPE_CHECKING: # pragma: nocover - assert _compat.is_array_api_obj(a) + a = cast(Array, xp.broadcast_to(a, singletons + a.shape)) nd_b, nd_a = b.ndim, a.ndim nd_max = max(nd_b, nd_a) if nd_a == 0 or nd_b == 0: return xp.multiply(a, b) - a_shape = a.shape - b_shape = b.shape + a_shape = eager_shape(a) + b_shape = eager_shape(b) # Equalise the shapes by prepending smaller one with 1s a_shape = (1,) * max(0, nd_b - nd_a) + a_shape @@ -587,16 +580,14 @@ def pad( ) -> Array: # numpydoc ignore=PR01,RT01 """See docstring in `array_api_extra._delegation.py`.""" # make pad_width a list of length-2 tuples of ints - x_ndim = cast(int, x.ndim) - if isinstance(pad_width, int): - pad_width_seq = [(pad_width, pad_width)] * x_ndim + pad_width_seq = [(pad_width, pad_width)] * x.ndim elif ( isinstance(pad_width, tuple) and len(pad_width) == 2 and all(isinstance(i, int) for i in pad_width) ): - pad_width_seq = [cast(tuple[int, int], pad_width)] * x_ndim + pad_width_seq = [cast(tuple[int, int], pad_width)] * x.ndim else: pad_width_seq = cast(list[tuple[int, int]], list(pad_width)) @@ -608,7 +599,8 @@ def pad( msg = f"expect a 2-tuple (before, after), got {w_tpl}." raise ValueError(msg) - sh = x.shape[ax] + sh = eager_shape(x)[ax] + if w_tpl[0] == 0 and w_tpl[1] == 0: sl = slice(None, None, None) else: @@ -674,20 +666,17 @@ def setdiff1d( """ if xp is None: xp = array_namespace(x1, x2) - x1, x2 = asarrays(x1, x2, xp=xp) + # FIXME https://github.com/microsoft/pyright/issues/10103 + x1_, x2_ = asarrays(x1, x2, xp=xp) if assume_unique: - x1 = xp.reshape(x1, (-1,)) - x2 = xp.reshape(x2, (-1,)) + x1_ = xp.reshape(x1_, (-1,)) + x2_ = xp.reshape(x2_, (-1,)) else: - x1 = xp.unique_values(x1) - x2 = xp.unique_values(x2) - - # FIXME https://github.com/microsoft/pyright/issues/10085 - if TYPE_CHECKING: # pragma: nocover - assert _compat.is_array_api_obj(x1) + x1_ = xp.unique_values(x1_) + x2_ = xp.unique_values(x2_) - return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)] + return x1_[_helpers.in1d(x1_, x2_, assume_unique=True, invert=True, xp=xp)] def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index 3cd72942..c6be5e97 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -7,6 +7,7 @@ import math from types import ModuleType +from typing import cast import pytest @@ -48,10 +49,11 @@ def _check_ns_shape_dtype( actual_shape = actual.shape desired_shape = desired.shape if is_dask_namespace(desired_xp): - if any(math.isnan(i) for i in actual_shape): - actual_shape = actual.compute().shape - if any(math.isnan(i) for i in desired_shape): - desired_shape = desired.compute().shape + # Dask uses nan instead of None for unknown shapes + if any(math.isnan(i) for i in cast(tuple[float, ...], actual_shape)): + actual_shape = actual.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + if any(math.isnan(i) for i in cast(tuple[float, ...], desired_shape)): + desired_shape = desired.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] msg = f"shapes do not match: {actual_shape} != f{desired_shape}" assert actual_shape == desired_shape, msg @@ -100,11 +102,11 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None: import numpy as np # pylint: disable=import-outside-toplevel if is_pydata_sparse_namespace(xp): - actual = actual.todense() - desired = desired.todense() + actual = actual.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + desired = desired.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] # JAX uses `np.testing` - np.testing.assert_array_equal(actual, desired, err_msg=err_msg) + np.testing.assert_array_equal(actual, desired, err_msg=err_msg) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] def xp_assert_close( @@ -164,13 +166,17 @@ def xp_assert_close( import numpy as np # pylint: disable=import-outside-toplevel if is_pydata_sparse_namespace(xp): - actual = actual.to_dense() - desired = desired.to_dense() + actual = actual.to_dense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + desired = desired.to_dense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] # JAX uses `np.testing` assert isinstance(rtol, float) - np.testing.assert_allclose( - actual, desired, rtol=rtol, atol=atol, err_msg=err_msg + np.testing.assert_allclose( # pyright: ignore[reportCallIssue] + actual, # pyright: ignore[reportArgumentType] + desired, # pyright: ignore[reportArgumentType] + rtol=rtol, + atol=atol, + err_msg=err_msg, # type: ignore[call-overload] ) diff --git a/src/array_api_extra/_lib/_utils/_compat.pyi b/src/array_api_extra/_lib/_utils/_compat.pyi index 66134fae..f40d7556 100644 --- a/src/array_api_extra/_lib/_utils/_compat.pyi +++ b/src/array_api_extra/_lib/_utils/_compat.pyi @@ -29,12 +29,12 @@ def is_jax_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... def is_numpy_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... def is_pydata_sparse_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... def is_torch_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... -def is_cupy_array(x: object, /) -> bool: ... -def is_dask_array(x: object, /) -> bool: ... -def is_jax_array(x: object, /) -> bool: ... -def is_numpy_array(x: object, /) -> bool: ... -def is_pydata_sparse_array(x: object, /) -> bool: ... -def is_torch_array(x: object, /) -> bool: ... -def is_lazy_array(x: object, /) -> bool: ... -def is_writeable_array(x: object, /) -> bool: ... +def is_cupy_array(x: object, /) -> TypeIs[Array]: ... +def is_dask_array(x: object, /) -> TypeIs[Array]: ... +def is_jax_array(x: object, /) -> TypeIs[Array]: ... +def is_numpy_array(x: object, /) -> TypeIs[Array]: ... +def is_pydata_sparse_array(x: object, /) -> TypeIs[Array]: ... +def is_torch_array(x: object, /) -> TypeIs[Array]: ... +def is_lazy_array(x: object, /) -> TypeIs[Array]: ... +def is_writeable_array(x: object, /) -> TypeIs[Array]: ... def size(x: Array, /) -> int | None: ... diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index 594b6e12..b0e39d06 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -3,9 +3,10 @@ # https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 from __future__ import annotations -from collections.abc import Generator +import math +from collections.abc import Generator, Iterable from types import ModuleType -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from . import _compat from ._compat import array_namespace, is_array_api_obj, is_numpy_array @@ -16,7 +17,7 @@ from typing_extensions import TypeIs -__all__ = ["asarrays", "in1d", "is_python_scalar", "mean"] +__all__ = ["asarrays", "eager_shape", "in1d", "is_python_scalar", "mean"] def in1d( @@ -41,14 +42,17 @@ def in1d( if xp is None: xp = array_namespace(x1, x2) + x1_shape = eager_shape(x1) + x2_shape = eager_shape(x2) + # This code is run to make the code significantly faster - if x2.shape[0] < 10 * x1.shape[0] ** 0.145: + if x2_shape[0] < 10 * x1_shape[0] ** 0.145 and isinstance(x2, Iterable): if invert: - mask = xp.ones(x1.shape[0], dtype=xp.bool, device=_compat.device(x1)) + mask = xp.ones(x1_shape[0], dtype=xp.bool, device=_compat.device(x1)) for a in x2: mask &= x1 != a else: - mask = xp.zeros(x1.shape[0], dtype=xp.bool, device=_compat.device(x1)) + mask = xp.zeros(x1_shape[0], dtype=xp.bool, device=_compat.device(x1)) for a in x2: mask |= x1 == a return mask @@ -146,7 +150,8 @@ def asarrays( a_scalar = is_python_scalar(a) b_scalar = is_python_scalar(b) if not a_scalar and not b_scalar: - return a, b # This includes misc. malformed input e.g. str + # This includes misc. malformed input e.g. str + return a, b # type: ignore[return-value] swap = False if a_scalar: @@ -165,7 +170,7 @@ def asarrays( float: ("real floating", "complex floating"), complex: "complex floating", } - kind = same_dtype[type(b)] # type: ignore[index] + kind = same_dtype[type(cast(complex, b))] # type: ignore[index] if xp.isdtype(a.dtype, kind): xb = xp.asarray(b, dtype=a.dtype) else: @@ -203,3 +208,25 @@ def ndindex(*x: int) -> Generator[tuple[int, ...]]: for i in ndindex(*x[:-1]): for j in range(x[-1]): yield *i, j + + +def eager_shape(x: Array, /) -> tuple[int, ...]: + """ + Return shape of an array. Raise if shape is not fully defined. + + Parameters + ---------- + x : Array + Input array. + + Returns + ------- + tuple[int, ...] + Shape of the array. + """ + shape = x.shape + # Dask arrays uses non-standard NaN instead of None + if any(s is None or math.isnan(s) for s in shape): + msg = "Unsupported lazy shape" + raise TypeError(msg) + return cast(tuple[int, ...], shape) diff --git a/src/array_api_extra/_lib/_utils/_typing.py b/src/array_api_extra/_lib/_utils/_typing.py index 95f29f79..68577c9a 100644 --- a/src/array_api_extra/_lib/_utils/_typing.py +++ b/src/array_api_extra/_lib/_utils/_typing.py @@ -1,11 +1,22 @@ -"""Static typing helpers.""" +# pylint: disable=missing-module-docstring # numpydoc ignore=GL08 +class Array: # pylint: disable=missing-class-docstring # numpydoc ignore=GL08 + pass -from typing import Any -# To be changed to a Protocol later (see data-apis/array-api#589) -Array = Any # type: ignore[no-any-explicit] -Device = Any # type: ignore[no-any-explicit] -DType = Any # type: ignore[no-any-explicit] -Index = Any # type: ignore[no-any-explicit] +class DType: # pylint: disable=missing-class-docstring # numpydoc ignore=GL08 + pass -__all__ = ["Array", "DType", "Device", "Index"] + +class Device: # pylint: disable=missing-class-docstring # numpydoc ignore=GL08 + pass + + +class GetIndex: # pylint: disable=missing-class-docstring # numpydoc ignore=GL08 + pass + + +class SetIndex: # pylint: disable=missing-class-docstring # numpydoc ignore=GL08 + pass + + +__all__ = ["Array", "DType", "Device", "GetIndex", "SetIndex"] diff --git a/src/array_api_extra/_lib/_utils/_typing.pyi b/src/array_api_extra/_lib/_utils/_typing.pyi new file mode 100644 index 00000000..9ef06162 --- /dev/null +++ b/src/array_api_extra/_lib/_utils/_typing.pyi @@ -0,0 +1,105 @@ +"""Static typing helpers.""" + +from __future__ import annotations + +from types import EllipsisType +from typing import Protocol, TypeAlias + +# TODO import from typing (requires Python >=3.12) +from typing_extensions import override + +# TODO: use array-api-typing once it is available + +class Array(Protocol): # pylint: disable=missing-class-docstring + # Unary operations + def __abs__(self) -> Array: ... + def __pos__(self) -> Array: ... + def __neg__(self) -> Array: ... + def __invert__(self) -> Array: ... + # Binary operations + def __add__(self, other: Array | complex, /) -> Array: ... + def __sub__(self, other: Array | complex, /) -> Array: ... + def __mul__(self, other: Array | complex, /) -> Array: ... + def __truediv__(self, other: Array | complex, /) -> Array: ... + def __floordiv__(self, other: Array | complex, /) -> Array: ... + def __mod__(self, other: Array | complex, /) -> Array: ... + def __pow__(self, other: Array | complex, /) -> Array: ... + def __matmul__(self, other: Array, /) -> Array: ... + def __and__(self, other: Array | int, /) -> Array: ... + def __or__(self, other: Array | int, /) -> Array: ... + def __xor__(self, other: Array | int, /) -> Array: ... + def __lshift__(self, other: Array | int, /) -> Array: ... + def __rshift__(self, other: Array | int, /) -> Array: ... + def __lt__(self, other: Array | complex, /) -> Array: ... + def __le__(self, other: Array | complex, /) -> Array: ... + def __gt__(self, other: Array | complex, /) -> Array: ... + def __ge__(self, other: Array | complex, /) -> Array: ... + @override + def __eq__(self, other: Array | complex, /) -> Array: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] + @override + def __ne__(self, other: Array | complex, /) -> Array: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] + # Reflected operations + def __radd__(self, other: Array | complex, /) -> Array: ... + def __rsub__(self, other: Array | complex, /) -> Array: ... + def __rmul__(self, other: Array | complex, /) -> Array: ... + def __rtruediv__(self, other: Array | complex, /) -> Array: ... + def __rfloordiv__(self, other: Array | complex, /) -> Array: ... + def __rmod__(self, other: Array | complex, /) -> Array: ... + def __rpow__(self, other: Array | complex, /) -> Array: ... + def __rmatmul__(self, other: Array, /) -> Array: ... + def __rand__(self, other: Array | int, /) -> Array: ... + def __ror__(self, other: Array | int, /) -> Array: ... + def __rxor__(self, other: Array | int, /) -> Array: ... + def __rlshift__(self, other: Array | int, /) -> Array: ... + def __rrshift__(self, other: Array | int, /) -> Array: ... + # Attributes + @property + def dtype(self) -> DType: ... + @property + def device(self) -> Device: ... + @property + def mT(self) -> Array: ... # pylint: disable=invalid-name + @property + def ndim(self) -> int: ... + @property + def shape(self) -> tuple[int | None, ...]: ... + @property + def size(self) -> int | None: ... + @property + def T(self) -> Array: ... # pylint: disable=invalid-name + # Collection operations (note: an Array does not have to be Sized or Iterable) + def __getitem__(self, key: GetIndex, /) -> Array: ... + def __setitem__(self, key: SetIndex, value: Array | complex, /) -> None: ... + # Materialization methods (may raise on lazy arrays) + def __bool__(self) -> bool: ... + def __complex__(self) -> complex: ... + def __float__(self) -> float: ... + def __index__(self) -> int: ... + def __int__(self) -> int: ... + + # Misc methods (frequently not implemented in Arrays wrapped by array-api-compat) + # def __array_namespace__(*, api_version: str | None) -> ModuleType: ... + # def __dlpack__( + # *, + # stream: int | Any | None = None, + # max_version: tuple[int, int] | None = None, + # dl_device: tuple[int, int] | None = None, # tuple[Enum, int] + # copy: bool | None = None, + # ) -> Any: ... + # def __dlpack_device__() -> tuple[int, int]: ... # tuple[Enum, int] + # def to_device(device: Device, /, *, stream: int | Any | None = None) -> Array: ... + +class DType(Protocol): # pylint: disable=missing-class-docstring + pass + +class Device(Protocol): # pylint: disable=missing-class-docstring + pass + +SetIndex: TypeAlias = ( # type: ignore[no-any-explicit] + int | slice | EllipsisType | Array | tuple[int | slice | EllipsisType | Array, ...] +) +GetIndex: TypeAlias = ( # type: ignore[no-any-explicit] + SetIndex | None | tuple[int | slice | EllipsisType | None | Array, ...] +) + +__all__ = ["Array", "DType", "Device", "GetIndex", "SetIndex"] diff --git a/src/array_api_extra/testing.py b/src/array_api_extra/testing.py index a0f97a81..b3782090 100644 --- a/src/array_api_extra/testing.py +++ b/src/array_api_extra/testing.py @@ -328,6 +328,6 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08 # Block until the graph materializes and reraise exceptions. This allows # `pytest.raises` and `pytest.warns` to work as expected. Note that this would # not work on scheduler='distributed', as it would not block. - return dask.persist(out, scheduler="threads")[0] # type: ignore[no-any-return,attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage] + return dask.persist(out, scheduler="threads")[0] # type: ignore[attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage] return wrapper diff --git a/tests/conftest.py b/tests/conftest.py index 5ba6dca6..6cb4e433 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,7 @@ T = TypeVar("T") P = ParamSpec("P") -np_compat = array_namespace(np.empty(0)) +np_compat = array_namespace(np.empty(0)) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] @pytest.fixture(params=tuple(Backend)) diff --git a/tests/test_at.py b/tests/test_at.py index e13a691a..9484a481 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -13,7 +13,7 @@ from array_api_extra._lib._at import _AtOp from array_api_extra._lib._testing import xp_assert_equal from array_api_extra._lib._utils._compat import array_namespace, is_writeable_array -from array_api_extra._lib._utils._typing import Array, Index +from array_api_extra._lib._utils._typing import Array, SetIndex from array_api_extra.testing import lazy_xp_function pytestmark = [ @@ -25,7 +25,7 @@ def at_op( x: Array, - idx: Index, + idx: SetIndex, op: _AtOp, y: Array | object, copy: bool | None = None, @@ -46,7 +46,7 @@ def at_op( def _at_op( x: Array, - idx: Index | None, + idx: SetIndex | None, idx_pickle: bytes | None, op: _AtOp, y: Array | object, @@ -56,7 +56,7 @@ def _at_op( """jitted helper of at_op""" if idx_pickle: idx = pickle.loads(idx_pickle) - meth = cast(Callable[..., Array], getattr(at(x, idx), op.value)) # type: ignore[no-any-explicit] + meth = cast(Callable[..., Array], getattr(at(x, cast(SetIndex, idx)), op.value)) # type: ignore[no-any-explicit] return meth(y, copy=copy, xp=xp) @@ -183,34 +183,35 @@ def test_copy_default(xp: ModuleType, library: Backend, op: _AtOp): def test_copy_invalid(): a = np.asarray([1, 2, 3]) with pytest.raises(ValueError, match="copy"): - at(a, 0).set(4, copy="invalid") # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + _ = at(a, 0).set(4, copy="invalid") # type: ignore[arg-type] # pyright: ignore[reportArgumentType] def test_xp(): - a = np.asarray([1, 2, 3]) - at(a, 0).set(4, xp=np) - at(a, 0).add(4, xp=np) - at(a, 0).subtract(4, xp=np) - at(a, 0).multiply(4, xp=np) - at(a, 0).divide(4, xp=np) - at(a, 0).power(4, xp=np) - at(a, 0).min(4, xp=np) - at(a, 0).max(4, xp=np) + a = cast(Array, np.asarray([1, 2, 3])) # type: ignore[bad-cast] # pyright: ignore[reportInvalidCast] + _ = at(a, 0).set(4, xp=np) + _ = at(a, 0).add(4, xp=np) + _ = at(a, 0).subtract(4, xp=np) + _ = at(a, 0).multiply(4, xp=np) + _ = at(a, 0).divide(4, xp=np) + _ = at(a, 0).power(4, xp=np) + _ = at(a, 0).min(4, xp=np) + _ = at(a, 0).max(4, xp=np) def test_alternate_index_syntax(): - a = np.asarray([1, 2, 3]) - xp_assert_equal(at(a, 0).set(4, copy=True), np.asarray([4, 2, 3])) - xp_assert_equal(at(a)[0].set(4, copy=True), np.asarray([4, 2, 3])) + xp = cast(ModuleType, np) # pyright: ignore[reportInvalidCast] + a = cast(Array, xp.asarray([1, 2, 3])) + xp_assert_equal(at(a, 0).set(4, copy=True), xp.asarray([4, 2, 3])) + xp_assert_equal(at(a)[0].set(4, copy=True), xp.asarray([4, 2, 3])) a_at = at(a) - xp_assert_equal(a_at[0].add(1, copy=True), np.asarray([2, 2, 3])) - xp_assert_equal(a_at[1].add(2, copy=True), np.asarray([1, 4, 3])) + xp_assert_equal(a_at[0].add(1, copy=True), xp.asarray([2, 2, 3])) + xp_assert_equal(a_at[1].add(2, copy=True), xp.asarray([1, 4, 3])) with pytest.raises(ValueError, match="Index"): - at(a).set(4) + _ = at(a).set(4) with pytest.raises(ValueError, match="Index"): - at(a, 0)[0].set(4) + _ = at(a, 0)[0].set(4) @pytest.mark.parametrize("copy", [True, None]) @@ -256,7 +257,7 @@ def test_incompatible_dtype( elif library is Backend.ARRAY_API_STRICT and op is not _AtOp.SET: with pytest.raises(Exception, match=r"cast|promote|dtype"): - at_op(x, idx, op, 1.1, copy=copy) + _ = at_op(x, idx, op, 1.1, copy=copy) elif op in (_AtOp.SET, _AtOp.MIN, _AtOp.MAX): # There is no __i__ version of these operations @@ -264,7 +265,7 @@ def test_incompatible_dtype( else: with pytest.raises(Exception, match=r"cast|promote|dtype"): - at_op(x, idx, op, 1.1, copy=copy) + _ = at_op(x, idx, op, 1.1, copy=copy) assert z is None or z.dtype == x.dtype diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 27318281..65d19aaa 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -23,7 +23,7 @@ from array_api_extra._lib import Backend from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal from array_api_extra._lib._utils._compat import device as get_device -from array_api_extra._lib._utils._helpers import ndindex +from array_api_extra._lib._utils._helpers import eager_shape, ndindex from array_api_extra._lib._utils._typing import Array, Device from array_api_extra.testing import lazy_xp_function @@ -249,7 +249,7 @@ def test_1d_from_scipy(self, xp: ModuleType, n: int, offset: int): def test_0d_raises(self, xp: ModuleType): with pytest.raises(ValueError, match="1-dimensional"): - create_diagonal(xp.asarray(1)) + _ = create_diagonal(xp.asarray(1)) @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()") @pytest.mark.parametrize( @@ -274,7 +274,7 @@ def test_nd(self, xp: ModuleType, shape: tuple[int, ...]): c = create_diagonal(b) zero = xp.zeros((), dtype=xp.uint64) assert c.shape == (*b.shape, b.shape[-1]) - for i in ndindex(*c.shape): + for i in ndindex(*eager_shape(c)): xp_assert_equal(c[i], b[i[:-1]] if i[-2] == i[-1] else zero) @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()") @@ -320,26 +320,26 @@ def test_axis_out_of_range(self, xp: ModuleType): s = (2, 3, 4, 5) a = xp.empty(s) with pytest.raises(IndexError, match="out of bounds"): - expand_dims(a, axis=-6) + _ = expand_dims(a, axis=-6) with pytest.raises(IndexError, match="out of bounds"): - expand_dims(a, axis=5) + _ = expand_dims(a, axis=5) a = xp.empty((3, 3, 3)) with pytest.raises(IndexError, match="out of bounds"): - expand_dims(a, axis=(0, -6)) + _ = expand_dims(a, axis=(0, -6)) with pytest.raises(IndexError, match="out of bounds"): - expand_dims(a, axis=(0, 5)) + _ = expand_dims(a, axis=(0, 5)) def test_repeated_axis(self, xp: ModuleType): a = xp.empty((3, 3, 3)) with pytest.raises(ValueError, match="Duplicate dimensions"): - expand_dims(a, axis=(1, 1)) + _ = expand_dims(a, axis=(1, 1)) def test_positive_negative_repeated(self, xp: ModuleType): # https://github.com/data-apis/array-api/issues/760#issuecomment-1989449817 a = xp.empty((2, 3, 4, 5)) with pytest.raises(ValueError, match="Duplicate dimensions"): - expand_dims(a, axis=(3, -3)) + _ = expand_dims(a, axis=(3, -3)) @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims") def test_device(self, xp: ModuleType, device: Device): @@ -505,7 +505,7 @@ def test_python_scalar(self, xp: ModuleType): def test_all_python_scalars(self): with pytest.raises(TypeError, match="Unrecognized"): - isclose(0, 0) + _ = isclose(0, 0) def test_xp(self, xp: ModuleType): a = xp.asarray([0.0, 0.0]) @@ -582,7 +582,7 @@ def test_python_scalar(self, xp: ModuleType): def test_all_python_scalars(self): with pytest.raises(TypeError, match="Unrecognized"): - kron(1, 1) + _ = kron(1, 1) def test_device(self, xp: ModuleType, device: Device): x1 = xp.asarray([1, 2, 3], device=device) @@ -634,7 +634,7 @@ def test_ndim(self, xp: ModuleType): def test_mode_not_implemented(self, xp: ModuleType): a = xp.arange(3) with pytest.raises(NotImplementedError, match="Only `'constant'`"): - pad(a, 2, mode="edge") # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + _ = pad(a, 2, mode="edge") # type: ignore[arg-type] # pyright: ignore[reportArgumentType] def test_device(self, xp: ModuleType, device: Device): a = xp.asarray(0.0, device=device) @@ -653,7 +653,7 @@ def test_tuple_width(self, xp: ModuleType): assert padded.shape == (6, 7) with pytest.raises((ValueError, RuntimeError)): - pad(a, [(1, 2, 3)]) # type: ignore[list-item] # pyright: ignore[reportArgumentType] + _ = pad(a, [(1, 2, 3)]) # type: ignore[list-item] # pyright: ignore[reportArgumentType] def test_sequence_of_tuples_width(self, xp: ModuleType): a = xp.reshape(xp.arange(12), (3, 4)) @@ -745,7 +745,7 @@ def test_python_scalar(self, xp: ModuleType, assume_unique: bool): @pytest.mark.parametrize("assume_unique", [True, False]) def test_all_python_scalars(self, assume_unique: bool): with pytest.raises(TypeError, match="Unrecognized"): - setdiff1d(0, 0, assume_unique=assume_unique) + _ = setdiff1d(0, 0, assume_unique=assume_unique) @assume_unique def test_device(self, xp: ModuleType, device: Device, assume_unique: bool): @@ -773,7 +773,7 @@ def test_simple(self, xp: ModuleType): @pytest.mark.parametrize("x", [0, 1 + 3j]) def test_dtype(self, xp: ModuleType, x: int | complex): with pytest.raises(ValueError, match="real floating data type"): - sinc(xp.asarray(x)) + _ = sinc(xp.asarray(x)) def test_3d(self, xp: ModuleType): x = xp.reshape(xp.arange(18, dtype=xp.float64), (3, 3, 2)) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 1710ff84..9895e2c5 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,4 +1,5 @@ from types import ModuleType +from typing import cast import numpy as np import pytest @@ -6,8 +7,8 @@ from array_api_extra._lib import Backend from array_api_extra._lib._testing import xp_assert_equal from array_api_extra._lib._utils._compat import device as get_device -from array_api_extra._lib._utils._helpers import asarrays, in1d, ndindex -from array_api_extra._lib._utils._typing import Device +from array_api_extra._lib._utils._helpers import asarrays, eager_shape, in1d, ndindex +from array_api_extra._lib._utils._typing import Array, Device, DType from array_api_extra.testing import lazy_xp_function # mypy: disable-error-code=no-untyped-usage @@ -139,12 +140,12 @@ def test_array_vs_array(self, a_type: str, b_type: str, xp: ModuleType): assert xb.dtype == b.dtype @pytest.mark.parametrize("dtype", [np.float64, np.complex128]) - def test_numpy_generics(self, dtype: type): + def test_numpy_generics(self, dtype: DType): """ Test special case of np.float64 and np.complex128, which are subclasses of float and complex. """ - a = dtype(0) + a = cast(Array, dtype(0)) # type: ignore[operator] # pyright: ignore[reportCallIssue] xa, xb = asarrays(a, 0, xp=np) assert xa.dtype == dtype assert xb.dtype == dtype @@ -155,3 +156,20 @@ def test_numpy_generics(self, dtype: type): ) def test_ndindex(shape: tuple[int, ...]): assert tuple(ndindex(*shape)) == tuple(np.ndindex(*shape)) + + +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array") +def test_eager_shape(xp: ModuleType, library: Backend): + a = xp.asarray([1, 2, 3]) + # Lazy arrays, like Dask, have an eager shape until you slice them with + # a lazy boolean mask + assert eager_shape(a) == a.shape == (3,) + + b = a[a > 2] + if library is Backend.DASK: + with pytest.raises(TypeError, match="Unsupported lazy shape"): + _ = eager_shape(b) + # FIXME can't test use case for None in the shape until we add support for + # other lazy backends + else: + assert eager_shape(b) == b.shape == (1,) diff --git a/tests/test_testing.py b/tests/test_testing.py index ed21feb2..b5ec7d85 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -1,5 +1,6 @@ from collections.abc import Callable from types import ModuleType +from typing import cast import numpy as np import pytest @@ -151,19 +152,19 @@ def test_lazy_xp_function(xp: ModuleType): with pytest.raises( TypeError, match="Attempted boolean conversion of traced array" ): - non_materializable4(x) # Wrapped + _ = non_materializable4(x) # Wrapped elif is_dask_namespace(xp): with pytest.raises( AssertionError, match=r"dask\.compute.* 2 times, but only up to 1 calls are allowed", ): - non_materializable3(x) + _ = non_materializable3(x) with pytest.raises( AssertionError, match=r"dask\.compute.* 1 times, but no calls are allowed", ): - non_materializable4(x) + _ = non_materializable4(x) else: xp_assert_equal(non_materializable3(x), xp.asarray([1.0, 2.0])) @@ -227,12 +228,12 @@ def test_lazy_xp_function_cython_ufuncs(xp: ModuleType, library: Backend): # eager jax arrays are auto-converted to numpy in eager jax # and fail in jax.jit (which lazy_xp_function tests here) with pytest.raises((TypeError, AssertionError)): - xp_assert_equal(erf(x), xp.asarray([1.0, 1.0])) + xp_assert_equal(cast(Array, erf(x)), xp.asarray([1.0, 1.0])) else: # cupy, dask and sparse define __array_ufunc__ and dispatch accordingly # note that when sparse reduces to scalar it returns a np.generic, which # would make xp_assert_equal fail. - xp_assert_equal(erf(x), xp.asarray([1.0, 1.0])) + xp_assert_equal(cast(Array, erf(x)), xp.asarray([1.0, 1.0])) def dask_raises(x: Array) -> Array: @@ -243,7 +244,7 @@ def _raises(x: Array) -> Array: msg = "Hello world" raise ValueError(msg) - return x.map_blocks(_raises, dtype=x.dtype, meta=x._meta) + return x.map_blocks(_raises, dtype=x.dtype, meta=x._meta) # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] lazy_xp_function(dask_raises) @@ -261,40 +262,44 @@ def test_lazy_xp_function_eagerly_raises(da: ModuleType): """ x = da.arange(3) with pytest.raises(ValueError, match="Hello world"): - dask_raises(x) + _ = dask_raises(x) -class Wrapped: - def f(x: Array) -> Array: # noqa: N805 # pyright: ignore[reportSelfClsParameterName] - xp = array_namespace(x) - # Crash in jax.jit and trigger compute() on dask - if not xp.all(x): - msg = "Values must be non-zero" - raise ValueError(msg) - return x +wrapped = ModuleType("wrapped") +naked = ModuleType("naked") -class Naked: - f = Wrapped.f # pyright: ignore[reportUnannotatedClassAttribute] +def f(x: Array) -> Array: + xp = array_namespace(x) + # Crash in jax.jit and trigger compute() on dask + if not xp.all(x): + msg = "Values must be non-zero" + raise ValueError(msg) + return x + + +wrapped.f = f # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] +naked.f = f # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] +del f -lazy_xp_function(Wrapped.f) -lazy_xp_modules = [Wrapped] +lazy_xp_function(wrapped.f) +lazy_xp_modules = [wrapped] def test_lazy_xp_modules(xp: ModuleType, library: Backend): x = xp.asarray([1.0, 2.0]) - y = Naked.f(x) + y = naked.f(x) xp_assert_equal(y, x) if library is Backend.JAX: with pytest.raises( TypeError, match="Attempted boolean conversion of traced array" ): - Wrapped.f(x) + wrapped.f(x) elif library is Backend.DASK: with pytest.raises(AssertionError, match=r"dask\.compute"): - Wrapped.f(x) + wrapped.f(x) else: - y = Wrapped.f(x) + y = wrapped.f(x) xp_assert_equal(y, x) From 94c2a122ad9e7abd1f20192fd5d8885bc0abd4ea Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Mon, 17 Mar 2025 17:38:48 +0000 Subject: [PATCH 2/3] Update src/array_api_extra/_lib/_funcs.py --- src/array_api_extra/_lib/_funcs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index dbb6bc14..43698c42 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -666,7 +666,7 @@ def setdiff1d( """ if xp is None: xp = array_namespace(x1, x2) - # FIXME https://github.com/microsoft/pyright/issues/10103 + # https://github.com/microsoft/pyright/issues/10103 x1_, x2_ = asarrays(x1, x2, xp=xp) if assume_unique: From ee48ad6452a0166eec624a641af45f88df6666d4 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 17 Mar 2025 17:46:42 +0000 Subject: [PATCH 3/3] More compact _typing.py --- src/array_api_extra/_lib/_utils/_typing.py | 28 +++++++--------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/src/array_api_extra/_lib/_utils/_typing.py b/src/array_api_extra/_lib/_utils/_typing.py index 68577c9a..d32a3a07 100644 --- a/src/array_api_extra/_lib/_utils/_typing.py +++ b/src/array_api_extra/_lib/_utils/_typing.py @@ -1,22 +1,10 @@ -# pylint: disable=missing-module-docstring # numpydoc ignore=GL08 -class Array: # pylint: disable=missing-class-docstring # numpydoc ignore=GL08 - pass - - -class DType: # pylint: disable=missing-class-docstring # numpydoc ignore=GL08 - pass - - -class Device: # pylint: disable=missing-class-docstring # numpydoc ignore=GL08 - pass - - -class GetIndex: # pylint: disable=missing-class-docstring # numpydoc ignore=GL08 - pass - - -class SetIndex: # pylint: disable=missing-class-docstring # numpydoc ignore=GL08 - pass - +# numpydoc ignore=GL08 +# pylint: disable=missing-module-docstring + +Array = object +DType = object +Device = object +GetIndex = object +SetIndex = object __all__ = ["Array", "DType", "Device", "GetIndex", "SetIndex"]