Skip to content

TYP: Full annotations for Array objects #159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ enable_error_code = ["ignore-without-code", "truthy-bool"]
# https://github.com/data-apis/array-api-typing
disallow_any_expr = false
# false positives with input validation
disable_error_code = ["redundant-expr", "unreachable"]
disable_error_code = ["redundant-expr", "unreachable", "no-any-return"]

[[tool.mypy.overrides]]
# slow/unavailable on Windows; do not add to the lint env
Expand Down
59 changes: 34 additions & 25 deletions src/array_api_extra/_lib/_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
is_jax_array,
is_writeable_array,
)
from ._utils._typing import Array, Index
from ._utils._typing import Array, SetIndex


class _AtOp(Enum):
Expand Down Expand Up @@ -43,7 +43,13 @@ def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[
return self.value


_undef = object()
class Undef(Enum):
"""Sentinel for undefined values."""

UNDEF = 0


_undef = Undef.UNDEF


class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
Expand Down Expand Up @@ -188,16 +194,16 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
"""

_x: Array
_idx: Index
_idx: SetIndex | Undef
__slots__: ClassVar[tuple[str, ...]] = ("_idx", "_x")

def __init__(
self, x: Array, idx: Index = _undef, /
self, x: Array, idx: SetIndex | Undef = _undef, /
) -> None: # numpydoc ignore=GL08
self._x = x
self._idx = idx

def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01
def __getitem__(self, idx: SetIndex, /) -> at: # numpydoc ignore=PR01,RT01
"""
Allow for the alternate syntax ``at(x)[start:stop:step]``.

Expand All @@ -212,9 +218,9 @@ def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01
def _op(
self,
at_op: _AtOp,
in_place_op: Callable[[Array, Array | object], Array] | None,
in_place_op: Callable[[Array, Array | complex], Array] | None,
out_of_place_op: Callable[[Array, Array], Array] | None,
y: Array | object,
y: Array | complex,
/,
copy: bool | None,
xp: ModuleType | None,
Expand All @@ -226,7 +232,7 @@ def _op(
----------
at_op : _AtOp
Method of JAX's Array.at[].
in_place_op : Callable[[Array, Array | object], Array] | None
in_place_op : Callable[[Array, Array | complex], Array] | None
In-place operation to apply on mutable backends::

x[idx] = in_place_op(x[idx], y)
Expand All @@ -245,7 +251,7 @@ def _op(

x = xp.where(idx, y, x)

y : array or object
y : array or complex
Right-hand side of the operation.
copy : bool or None
Whether to copy the input array. See the class docstring for details.
Expand All @@ -260,7 +266,7 @@ def _op(
x, idx = self._x, self._idx
xp = array_namespace(x, y) if xp is None else xp

if idx is _undef:
if isinstance(idx, Undef):
msg = (
"Index has not been set.\n"
"Usage: either\n"
Expand Down Expand Up @@ -306,7 +312,10 @@ def _op(
if copy or (copy is None and not writeable):
if is_jax_array(x):
# Use JAX's at[]
func = cast(Callable[[Array], Array], getattr(x.at[idx], at_op.value))
func = cast(
Callable[[Array | complex], Array],
getattr(x.at[idx], at_op.value), # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue,reportUnknownArgumentType]
)
out = func(y)
# Undo int->float promotion on JAX after _AtOp.DIVIDE
return xp.astype(out, x.dtype, copy=False)
Expand All @@ -315,10 +324,10 @@ def _op(
# with a copy followed by an update

x = xp.asarray(x, copy=True)
if writeable is False:
# A copy of a read-only numpy array is writeable
# Note: this assumes that a copy of a writeable array is writeable
writeable = None
# A copy of a read-only numpy array is writeable
# Note: this assumes that a copy of a writeable array is writeable
assert not writeable
writeable = None

if writeable is None:
writeable = is_writeable_array(x)
Expand All @@ -328,14 +337,14 @@ def _op(
raise ValueError(msg)

if in_place_op: # add(), subtract(), ...
x[self._idx] = in_place_op(x[self._idx], y)
x[idx] = in_place_op(x[idx], y)
else: # set()
x[self._idx] = y
x[idx] = y
return x

def set(
self,
y: Array | object,
y: Array | complex,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
Expand All @@ -345,7 +354,7 @@ def set(

def add(
self,
y: Array | object,
y: Array | complex,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
Expand All @@ -359,7 +368,7 @@ def add(

def subtract(
self,
y: Array | object,
y: Array | complex,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
Expand All @@ -371,7 +380,7 @@ def subtract(

def multiply(
self,
y: Array | object,
y: Array | complex,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
Expand All @@ -383,7 +392,7 @@ def multiply(

def divide(
self,
y: Array | object,
y: Array | complex,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
Expand All @@ -395,7 +404,7 @@ def divide(

def power(
self,
y: Array | object,
y: Array | complex,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
Expand All @@ -405,7 +414,7 @@ def power(

def min(
self,
y: Array | object,
y: Array | complex,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
Expand All @@ -417,7 +426,7 @@ def min(

def max(
self,
y: Array | object,
y: Array | complex,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
Expand Down
59 changes: 24 additions & 35 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
import warnings
from collections.abc import Sequence
from types import ModuleType
from typing import TYPE_CHECKING, cast
from typing import cast

from ._at import at
from ._utils import _compat, _helpers
from ._utils._compat import array_namespace, is_jax_array
from ._utils._helpers import asarrays, ndindex
from ._utils._helpers import asarrays, eager_shape, ndindex
from ._utils._typing import Array

__all__ = [
Expand Down Expand Up @@ -211,11 +211,13 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
m = xp.astype(m, dtype)

avg = _helpers.mean(m, axis=1, xp=xp)
fact = m.shape[1] - 1

m_shape = eager_shape(m)
fact = m_shape[1] - 1

if fact <= 0:
warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2)
fact = 0.0
fact = 0

m -= avg[:, None]
m_transpose = m.T
Expand Down Expand Up @@ -274,8 +276,10 @@ def create_diagonal(
if x.ndim == 0:
err_msg = "`x` must be at least 1-dimensional."
raise ValueError(err_msg)
batch_dims = x.shape[:-1]
n = x.shape[-1] + abs(offset)

x_shape = eager_shape(x)
batch_dims = x_shape[:-1]
n = x_shape[-1] + abs(offset)
diag = xp.zeros((*batch_dims, n**2), dtype=x.dtype, device=_compat.device(x))

target_slice = slice(
Expand Down Expand Up @@ -385,10 +389,6 @@ def isclose(
) -> Array: # numpydoc ignore=PR01,RT01
"""See docstring in array_api_extra._delegation."""
a, b = asarrays(a, b, xp=xp)
# FIXME https://github.com/microsoft/pyright/issues/10085
if TYPE_CHECKING: # pragma: nocover
assert _compat.is_array_api_obj(a)
assert _compat.is_array_api_obj(b)

a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
Expand Down Expand Up @@ -505,24 +505,17 @@ def kron(
if xp is None:
xp = array_namespace(a, b)
a, b = asarrays(a, b, xp=xp)
# FIXME https://github.com/microsoft/pyright/issues/10085
if TYPE_CHECKING: # pragma: nocover
assert _compat.is_array_api_obj(a)
assert _compat.is_array_api_obj(b)

singletons = (1,) * (b.ndim - a.ndim)
a = xp.broadcast_to(a, singletons + a.shape)
# FIXME https://github.com/microsoft/pyright/issues/10085
if TYPE_CHECKING: # pragma: nocover
assert _compat.is_array_api_obj(a)
a = cast(Array, xp.broadcast_to(a, singletons + a.shape))

nd_b, nd_a = b.ndim, a.ndim
nd_max = max(nd_b, nd_a)
if nd_a == 0 or nd_b == 0:
return xp.multiply(a, b)

a_shape = a.shape
b_shape = b.shape
a_shape = eager_shape(a)
b_shape = eager_shape(b)

# Equalise the shapes by prepending smaller one with 1s
a_shape = (1,) * max(0, nd_b - nd_a) + a_shape
Expand Down Expand Up @@ -587,16 +580,14 @@ def pad(
) -> Array: # numpydoc ignore=PR01,RT01
"""See docstring in `array_api_extra._delegation.py`."""
# make pad_width a list of length-2 tuples of ints
x_ndim = cast(int, x.ndim)

if isinstance(pad_width, int):
pad_width_seq = [(pad_width, pad_width)] * x_ndim
pad_width_seq = [(pad_width, pad_width)] * x.ndim
elif (
isinstance(pad_width, tuple)
and len(pad_width) == 2
and all(isinstance(i, int) for i in pad_width)
):
pad_width_seq = [cast(tuple[int, int], pad_width)] * x_ndim
pad_width_seq = [cast(tuple[int, int], pad_width)] * x.ndim
else:
pad_width_seq = cast(list[tuple[int, int]], list(pad_width))

Expand All @@ -608,7 +599,8 @@ def pad(
msg = f"expect a 2-tuple (before, after), got {w_tpl}."
raise ValueError(msg)

sh = x.shape[ax]
sh = eager_shape(x)[ax]

if w_tpl[0] == 0 and w_tpl[1] == 0:
sl = slice(None, None, None)
else:
Expand Down Expand Up @@ -674,20 +666,17 @@ def setdiff1d(
"""
if xp is None:
xp = array_namespace(x1, x2)
x1, x2 = asarrays(x1, x2, xp=xp)
# https://github.com/microsoft/pyright/issues/10103
x1_, x2_ = asarrays(x1, x2, xp=xp)

if assume_unique:
x1 = xp.reshape(x1, (-1,))
x2 = xp.reshape(x2, (-1,))
x1_ = xp.reshape(x1_, (-1,))
x2_ = xp.reshape(x2_, (-1,))
else:
x1 = xp.unique_values(x1)
x2 = xp.unique_values(x2)

# FIXME https://github.com/microsoft/pyright/issues/10085
if TYPE_CHECKING: # pragma: nocover
assert _compat.is_array_api_obj(x1)
x1_ = xp.unique_values(x1_)
x2_ = xp.unique_values(x2_)

return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]
return x1_[_helpers.in1d(x1_, x2_, assume_unique=True, invert=True, xp=xp)]


def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
Expand Down
Loading