Skip to content

Commit 5f4fa99

Browse files
authored
TST: Run all tests on read-only numpy arrays (#92)
1 parent 7f1f4d5 commit 5f4fa99

File tree

5 files changed

+83
-37
lines changed

5 files changed

+83
-37
lines changed

src/array_api_extra/_lib/_backends.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,23 @@ class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-an
1717
Parameters
1818
----------
1919
value : str
20-
String describing the backend.
20+
Name of the backend's module.
2121
is_namespace : Callable[[ModuleType], bool]
2222
Function to check whether an input module is the array namespace
2323
corresponding to the backend.
24-
module_name : str
25-
Name of the backend's module.
2624
"""
2725

28-
ARRAY_API_STRICT = (
29-
"array_api_strict",
30-
_compat.is_array_api_strict_namespace,
31-
"array_api_strict",
32-
)
33-
NUMPY = "numpy", _compat.is_numpy_namespace, "numpy"
34-
NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace, "numpy"
35-
CUPY = "cupy", _compat.is_cupy_namespace, "cupy"
36-
TORCH = "torch", _compat.is_torch_namespace, "torch"
37-
DASK_ARRAY = "dask.array", _compat.is_dask_namespace, "dask.array"
38-
SPARSE = "sparse", _compat.is_pydata_sparse_namespace, "sparse"
39-
JAX_NUMPY = "jax.numpy", _compat.is_jax_namespace, "jax.numpy"
26+
ARRAY_API_STRICT = "array_api_strict", _compat.is_array_api_strict_namespace
27+
NUMPY = "numpy", _compat.is_numpy_namespace
28+
NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace
29+
CUPY = "cupy", _compat.is_cupy_namespace
30+
TORCH = "torch", _compat.is_torch_namespace
31+
DASK_ARRAY = "dask.array", _compat.is_dask_namespace
32+
SPARSE = "sparse", _compat.is_pydata_sparse_namespace
33+
JAX_NUMPY = "jax.numpy", _compat.is_jax_namespace
4034

4135
def __new__(
42-
cls, value: str, _is_namespace: Callable[[ModuleType], bool], _module_name: str
36+
cls, value: str, _is_namespace: Callable[[ModuleType], bool]
4337
): # numpydoc ignore=GL08
4438
obj = object.__new__(cls)
4539
obj._value_ = value
@@ -49,10 +43,8 @@ def __init__(
4943
self,
5044
value: str, # noqa: ARG002 # pylint: disable=unused-argument
5145
is_namespace: Callable[[ModuleType], bool],
52-
module_name: str,
5346
): # numpydoc ignore=GL08
5447
self.is_namespace = is_namespace
55-
self.module_name = module_name
5648

5749
def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
5850
"""Pretty-print parameterized test names."""

tests/conftest.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
11
"""Pytest fixtures."""
22

3+
from collections.abc import Callable
4+
from functools import wraps
35
from types import ModuleType
4-
from typing import cast
6+
from typing import ParamSpec, TypeVar, cast
57

8+
import numpy as np
69
import pytest
710

811
from array_api_extra._lib import Backend
912
from array_api_extra._lib._utils._compat import array_namespace
1013
from array_api_extra._lib._utils._compat import device as get_device
1114
from array_api_extra._lib._utils._typing import Device
1215

16+
T = TypeVar("T")
17+
P = ParamSpec("P")
18+
19+
np_compat = array_namespace(np.empty(0))
20+
1321

1422
@pytest.fixture(params=tuple(Backend))
1523
def library(request: pytest.FixtureRequest) -> Backend: # numpydoc ignore=PR01,RT03
@@ -34,6 +42,56 @@ def library(request: pytest.FixtureRequest) -> Backend: # numpydoc ignore=PR01,
3442
return elem
3543

3644

45+
class NumPyReadOnly:
46+
"""
47+
Variant of array_api_compat.numpy producing read-only arrays.
48+
49+
Read-only numpy arrays fail on `__iadd__` etc., whereas read-only libraries such as
50+
JAX and Sparse simply don't define those methods, which makes calls to `+=` fall
51+
back to `__add__`.
52+
53+
Note that this is not a full read-only Array API library. Notably,
54+
`array_namespace(x)` returns array_api_compat.numpy. This is actually the desired
55+
behaviour, so that when a tested function internally calls `xp =
56+
array_namespace(*args) or xp`, it will internally create writeable arrays.
57+
For this reason, tests that explicitly pass xp=xp to the tested functions may
58+
misbehave and should be skipped for NUMPY_READONLY.
59+
"""
60+
61+
def __getattr__(self, name: str) -> object: # numpydoc ignore=PR01,RT01
62+
"""Wrap all functions that return arrays to make their output read-only."""
63+
func = getattr(np_compat, name)
64+
if not callable(func) or isinstance(func, type):
65+
return func
66+
return self._wrap(func)
67+
68+
@staticmethod
69+
def _wrap(func: Callable[P, T]) -> Callable[P, T]: # numpydoc ignore=PR01,RT01
70+
"""Wrap func to make all np.ndarrays it returns read-only."""
71+
72+
def as_readonly(o: T) -> T: # numpydoc ignore=PR01,RT01
73+
"""Unset the writeable flag in o."""
74+
try:
75+
# Don't use is_numpy_array(o), as it includes np.generic
76+
if isinstance(o, np.ndarray):
77+
o.flags.writeable = False
78+
except TypeError:
79+
# Cannot interpret as a data type
80+
return o
81+
82+
# This works with namedtuples too
83+
if isinstance(o, tuple | list):
84+
return type(o)(*(as_readonly(i) for i in o)) # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType,reportUnknownArgumentType]
85+
86+
return o
87+
88+
@wraps(func)
89+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
90+
return as_readonly(func(*args, **kwargs))
91+
92+
return wrapper
93+
94+
3795
@pytest.fixture
3896
def xp(library: Backend) -> ModuleType: # numpydoc ignore=PR01,RT03
3997
"""
@@ -43,7 +101,9 @@ def xp(library: Backend) -> ModuleType: # numpydoc ignore=PR01,RT03
43101
-------
44102
The current array namespace.
45103
"""
46-
xp = pytest.importorskip(library.module_name)
104+
if library == Backend.NUMPY_READONLY:
105+
return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType]
106+
xp = pytest.importorskip(library.value)
47107
if library == Backend.JAX_NUMPY:
48108
import jax # type: ignore[import-not-found] # pyright: ignore[reportMissingImports]
49109

tests/test_at.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import pytest
88
from array_api_compat import ( # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs]
99
array_namespace,
10-
is_pydata_sparse_array,
1110
is_writeable_array,
1211
)
1312

@@ -18,14 +17,6 @@
1817
from array_api_extra._lib._utils._typing import Array
1918

2019

21-
@pytest.fixture
22-
def array(library: Backend, xp: ModuleType) -> Array:
23-
x = xp.asarray([10.0, 20.0, 30.0])
24-
if library == Backend.NUMPY_READONLY:
25-
x.flags.writeable = False
26-
return x
27-
28-
2920
@contextmanager
3021
def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
3122
if copy is False and not is_writeable_array(array):
@@ -42,6 +33,9 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
4233
xp_assert_equal(xp.all(array == array_orig), xp.asarray(copy))
4334

4435

36+
@pytest.mark.skip_xp_backend(
37+
Backend.SPARSE, reason="read-only backend without .at support"
38+
)
4539
@pytest.mark.parametrize(
4640
("kwargs", "expect_copy"),
4741
[
@@ -66,15 +60,13 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
6660
)
6761
def test_update_ops(
6862
xp: ModuleType,
69-
array: Array,
7063
kwargs: dict[str, bool | None],
7164
expect_copy: bool | None,
7265
op: _AtOp,
7366
arg: float,
7467
expect: list[float],
7568
):
76-
if is_pydata_sparse_array(array):
77-
pytest.skip("at() does not support updates on sparse arrays")
69+
array = xp.asarray([10.0, 20.0, 30.0])
7870

7971
with assert_copy(array, expect_copy):
8072
func = cast(Callable[..., Array], getattr(at(array)[1:], op.value)) # type: ignore[no-any-explicit]

tests/test_funcs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def test_device(self, xp: ModuleType, device: Device):
136136
x = xp.asarray([1, 2, 3], device=device)
137137
assert get_device(cov(x)) == device
138138

139+
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY)
139140
def test_xp(self, xp: ModuleType):
140141
xp_assert_close(
141142
cov(xp.asarray([[0.0, 2.0], [1.0, 1.0], [2.0, 0.0]]).T, xp=xp),
@@ -366,6 +367,7 @@ def test_device(self, xp: ModuleType, device: Device):
366367
x2 = xp.asarray([2, 3, 4], device=device)
367368
assert get_device(setdiff1d(x1, x2)) == device
368369

370+
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY)
369371
def test_xp(self, xp: ModuleType):
370372
x1 = xp.asarray([3, 8, 20])
371373
x2 = xp.asarray([2, 3, 4])

tests/test_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
from types import ModuleType
22

3-
import numpy as np
43
import pytest
54

65
from array_api_extra._lib import Backend
76
from array_api_extra._lib._testing import xp_assert_equal
87
from array_api_extra._lib._utils._compat import device as get_device
98
from array_api_extra._lib._utils._helpers import in1d
10-
from array_api_extra._lib._utils._typing import Array, Device
9+
from array_api_extra._lib._utils._typing import Device
1110

1211
# mypy: disable-error-code=no-untyped-usage
1312

@@ -16,10 +15,10 @@ class TestIn1D:
1615
@pytest.mark.skip_xp_backend(Backend.DASK_ARRAY, reason="no argsort")
1716
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no unique_inverse, no device")
1817
# cover both code paths
19-
@pytest.mark.parametrize("x2", [np.arange(9), np.arange(15)])
20-
def test_no_invert_assume_unique(self, xp: ModuleType, x2: Array):
18+
@pytest.mark.parametrize("n", [9, 15])
19+
def test_no_invert_assume_unique(self, xp: ModuleType, n: int):
2120
x1 = xp.asarray([3, 8, 20])
22-
x2 = xp.asarray(x2)
21+
x2 = xp.arange(n)
2322
expected = xp.asarray([True, True, False])
2423
actual = in1d(x1, x2)
2524
xp_assert_equal(actual, expected)
@@ -30,6 +29,7 @@ def test_device(self, xp: ModuleType, device: Device):
3029
x2 = xp.asarray([2, 3, 4], device=device)
3130
assert get_device(in1d(x1, x2)) == device
3231

32+
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY)
3333
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no arange, no device")
3434
def test_xp(self, xp: ModuleType):
3535
x1 = xp.asarray([1, 6])

0 commit comments

Comments
 (0)