Skip to content

POC: appease linter for gh-53 #59

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/test-vendor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ jobs:
- name: Checkout array-api-compat
uses: actions/checkout@v4
with:
repository: data-apis/array-api-compat
# DNM
# repository: data-apis/array-api-compat
repository: crusaderky/array-api-compat
ref: d7ab986843cc9eb20882d7ccbf7248d78fcbd759
# /DNM
path: array-api-compat

- name: Vendor array-api-extra into test package
Expand Down
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
:nosignatures:
:toctree: generated

at
atleast_nd
cov
create_diagonal
Expand Down
105 changes: 59 additions & 46 deletions pixi.lock

Large diffs are not rendered by default.

11 changes: 9 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ classifiers = [
"Typing :: Typed",
]
dynamic = ["version"]
dependencies = ["array-api-compat>=1.1.1"]
# DNM
# dependencies = ["array-api-compat>=1.1.1"]
dependencies = []

[project.optional-dependencies]
tests = [
Expand Down Expand Up @@ -63,9 +65,12 @@ platforms = ["linux-64", "osx-arm64", "win-64"]

[tool.pixi.dependencies]
python = ">=3.10.15,<3.14"
array-api-compat = ">=1.1.1"
# array-api-compat = ">=1.1.1" # DNM

[tool.pixi.pypi-dependencies]
# DNM main plus #205, #207, #211
array-api-compat = { git = "https://github.com/crusaderky/array-api-compat.git", rev = "d7ab986843cc9eb20882d7ccbf7248d78fcbd759" }

array-api-extra = { path = ".", editable = true }

[tool.pixi.feature.lint.dependencies]
Expand Down Expand Up @@ -190,6 +195,8 @@ reportAny = false
reportExplicitAny = false
# data-apis/array-api-strict#6
reportUnknownMemberType = false
# no array-api-compat type stubs
reportUnknownVariableType = false


# Ruff
Expand Down
12 changes: 11 additions & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990

from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
from ._funcs import (
at,
atleast_nd,
cov,
create_diagonal,
expand_dims,
kron,
setdiff1d,
sinc,
)

__version__ = "0.3.3.dev0"

# pylint: disable=duplicate-code
__all__ = [
"__version__",
"at",
"atleast_nd",
"cov",
"create_diagonal",
Expand Down
293 changes: 290 additions & 3 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990

import operator
import typing
import warnings

if typing.TYPE_CHECKING:
from ._lib._typing import Array, ModuleType
# https://github.com/pylint-dev/pylint/issues/10112
from collections.abc import Callable # pylint: disable=import-error
from typing import ClassVar

from ._lib import _utils
from ._lib._compat import array_namespace
from ._lib._compat import (
array_namespace,
is_array_api_obj,
is_dask_array,
is_writeable_array,
)

if typing.TYPE_CHECKING:
from ._lib._typing import Array, Index, ModuleType, Untyped

__all__ = [
"at",
"atleast_nd",
"cov",
"create_diagonal",
Expand Down Expand Up @@ -548,3 +559,279 @@
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device),
)
return xp.sin(y) / y


_undef = object()


class at: # pylint: disable=invalid-name
"""
Update operations for read-only arrays.

This implements ``jax.numpy.ndarray.at`` for all backends.

Parameters
----------
x : array
Input array.
idx : index, optional
You may use two alternate syntaxes::

at(x, idx).set(value) # or get(), add(), etc.
at(x)[idx].set(value)

copy : bool, optional
True (default)
Ensure that the inputs are not modified.
False
Ensure that the update operation writes back to the input.
Raise ValueError if a copy cannot be avoided.
None
The array parameter *may* be modified in place if it is possible and
beneficial for performance.
You should not reuse it after calling this function.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer

**kwargs:
If the backend supports an `at` method, any additional keyword
arguments are passed to it verbatim; e.g. this allows passing
``indices_are_sorted=True`` to JAX.

Returns
-------
Updated input array.

Examples
--------
Given either of these equivalent expressions::

x = at(x)[1].add(2, copy=None)
x = at(x, 1).add(2, copy=None)

If x is a JAX array, they are the same as::

x = x.at[1].add(2)

If x is a read-only numpy array, they are the same as::

x = x.copy()
x[1] += 2

Otherwise, they are the same as::

x[1] += 2

Warning
-------
When you use copy=None, you should always immediately overwrite
the parameter array::

x = at(x, 0).set(2, copy=None)

The anti-pattern below must be avoided, as it will result in different behaviour
on read-only versus writeable arrays::

x = xp.asarray([0, 0, 0])
y = at(x, 0).set(2, copy=None)
z = at(x, 1).set(3, copy=None)

In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable!

Warning
-------
The array API standard does not support integer array indices.
The behaviour of update methods when the index is an array of integers
is undefined; this is particularly true when the index contains multiple
occurrences of the same index, e.g. ``at(x, [0, 0]).set(2)``.

Note
----
`sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet.

See Also
--------
`jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_
"""

x: Array
idx: Index
__slots__: ClassVar[tuple[str, str]] = ("idx", "x")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO the linter should not force me to define the type of __slots__, because it's part of the python data model. This only adds attrition and reduces readability.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree

Copy link
Member Author

@lucascolley lucascolley Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like this is no longer needed

EDIT: only if at is made final


def __init__(self, x: Array, idx: Index = _undef, /):
self.x = x
self.idx = idx

def __getitem__(self, idx: Index) -> at:
"""Allow for the alternate syntax ``at(x)[start:stop:step]``,
which looks prettier than ``at(x, slice(start, stop, step))``
and feels more intuitive coming from the JAX documentation.
"""
if self.idx is not _undef:
msg = "Index has already been set"
raise ValueError(msg)
self.idx = idx
return self

Check warning on line 675 in src/array_api_extra/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_funcs.py#L671-L675

Added lines #L671 - L675 were not covered by tests

def _common(
self,
at_op: str,
y: Array = _undef,
/,
copy: bool | None = True,
xp: ModuleType | None = None,
_is_update: bool = True,
**kwargs: Untyped,
) -> tuple[Untyped, None] | tuple[None, Array]:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it perhaps possible to @overload these cases?

Copy link
Contributor

@crusaderky crusaderky Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It isn't because the return type depends on a duck-type test on x. And I'm definitely unwilling to explore writing a class HasAtMethod(Protocol) for a small internal function that is consumed exclusively 2 paragraph below.

Copy link

@jorenham jorenham Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm definitely unwilling to explore writing a class HasAtMethod(Protocol) for a small internal function that is consumed exclusively 2 paragraph below.

I am willing, so here you go:

class _CanAt(Protocol):
    @property
    def at(self) -> Mapping[Index, Untyped] ...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

making this change - could you clarify how to use _CanAt @jorenham ?

"""Perform common prepocessing.

Returns
-------
If the operation can be resolved by at[], (return value, None)
Otherwise, (None, preprocessed x)
"""
if self.idx is _undef:
msg = (

Check warning on line 695 in src/array_api_extra/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_funcs.py#L695

Added line #L695 was not covered by tests
"Index has not been set.\n"
"Usage: either\n"
" at(x, idx).set(value)\n"
"or\n"
" at(x)[idx].set(value)\n"
"(same for all other methods)."
)
raise TypeError(msg)

Check warning on line 703 in src/array_api_extra/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_funcs.py#L703

Added line #L703 was not covered by tests

x = self.x

if copy is True:
writeable = None
elif copy is False:
writeable = is_writeable_array(x)
if not writeable:
msg = "Cannot modify parameter in place"
raise ValueError(msg)
elif copy is None: # type: ignore[redundant-expr]
writeable = is_writeable_array(x)
copy = _is_update and not writeable
else:
msg = f"Invalid value for copy: {copy!r}" # type: ignore[unreachable] # pyright: ignore[reportUnreachable]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a better way to deal with situations like this where invalid types can be passed at runtime?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This double-ignore is a strong example of why I'm strongly in favour of having only one type checker. This is tedious to both read and write.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's definitely annoying, but unfortunately it's the only possibility for ensuring compatibility with these two main type-checkers

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but I agree with mypy and pyright that the else clause here is redundant

Copy link

@jorenham jorenham Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you could alternatively move the if copy is None to the top, and then do elif copy: ... else: .... That way you'd also allow e.g. 0, 1, and the np.bool_ values (at runtime)

Copy link
Member Author

@lucascolley lucascolley Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if we still want to throw an error at runtime if copy="foo" is passed? I guess the argument is that, since this is internal, that would always be caught by the static analysis anyway?

EDIT: ah, it isn't internal, since the methods pass**kwargs through. I suppose this can be resolved once the methods are given explicit copy and xp kwargs

raise ValueError(msg)

if copy:
try:
at_ = x.at
except AttributeError:
# Emulate at[] behaviour for non-JAX arrays
# with a copy followed by an update
if xp is None:
xp = array_namespace(x)
# Create writeable copy of read-only numpy array
x = xp.asarray(x, copy=True)
if writeable is False:
# A copy of a read-only numpy array is writeable
writeable = None
else:
# Use JAX's at[] or other library that with the same duck-type API
args = (y,) if y is not _undef else ()
return getattr(at_[self.idx], at_op)(*args, **kwargs), None

Check warning on line 737 in src/array_api_extra/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_funcs.py#L736-L737

Added lines #L736 - L737 were not covered by tests

if _is_update:
if writeable is None:
writeable = is_writeable_array(x)
if not writeable:
# sparse crashes here
msg = f"Array {x} has no `at` method and is read-only"
raise ValueError(msg)

Check warning on line 745 in src/array_api_extra/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_funcs.py#L744-L745

Added lines #L744 - L745 were not covered by tests

return None, x

def get(self, **kwargs: Untyped) -> Untyped:
"""Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
that the output is either a copy or a view; it also allows passing
keyword arguments to the backend.
"""
if kwargs.get("copy") is False:
if is_array_api_obj(self.idx):
# Boolean index. Note that the array API spec
# https://data-apis.org/array-api/latest/API_specification/indexing.html
# does not allow for list, tuple, and tuples of slices plus one or more
# one-dimensional array indices, although many backends support them.
# So this check will encounter a lot of false negatives in real life,
# which can be caught by testing the user code vs. array-api-strict.
msg = "get() with an array index always returns a copy"
raise ValueError(msg)
if is_dask_array(self.x):
msg = "get() on Dask arrays always returns a copy"
raise ValueError(msg)

Check warning on line 766 in src/array_api_extra/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_funcs.py#L765-L766

Added lines #L765 - L766 were not covered by tests

res, x = self._common("get", _is_update=False, **kwargs)
if res is not None:
return res

Check warning on line 770 in src/array_api_extra/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_funcs.py#L770

Added line #L770 was not covered by tests
assert x is not None
return x[self.idx]

def set(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] = y`` and return the update array"""
res, x = self._common("set", y, **kwargs)
if res is not None:
return res

Check warning on line 778 in src/array_api_extra/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_funcs.py#L778

Added line #L778 was not covered by tests
assert x is not None
x[self.idx] = y
return x

def _iop(
self,
at_op: str,
elwise_op: Callable[[Array, Array], Array],
y: Array,
/,
**kwargs: Untyped,
) -> Array:
"""x[idx] += y or equivalent in-place operation on a subset of x

which is the same as saying
x[idx] = x[idx] + y
Note that this is not the same as
operator.iadd(x[idx], y)
Consider for example when x is a numpy array and idx is a fancy index, which
triggers a deep copy on __getitem__.
"""
res, x = self._common(at_op, y, **kwargs)
if res is not None:
return res

Check warning on line 802 in src/array_api_extra/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_funcs.py#L802

Added line #L802 was not covered by tests
assert x is not None
x[self.idx] = elwise_op(x[self.idx], y)
return x

def add(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] += y`` and return the updated array"""
return self._iop("add", operator.add, y, **kwargs)

def subtract(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] -= y`` and return the updated array"""
return self._iop("subtract", operator.sub, y, **kwargs)

def multiply(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] *= y`` and return the updated array"""
return self._iop("multiply", operator.mul, y, **kwargs)

def divide(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] /= y`` and return the updated array"""
return self._iop("divide", operator.truediv, y, **kwargs)

def power(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] **= y`` and return the updated array"""
return self._iop("power", operator.pow, y, **kwargs)

def min(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
xp = array_namespace(self.x)
y = xp.asarray(y)
return self._iop("min", xp.minimum, y, **kwargs)

def max(self, y: Array, /, **kwargs: Untyped) -> Array:
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
xp = array_namespace(self.x)
y = xp.asarray(y)
return self._iop("max", xp.maximum, y, **kwargs)
Loading
Loading