Skip to content

TST: xfail_xp_backend #132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3,454 changes: 1,652 additions & 1,802 deletions pixi.lock

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,10 @@ xfail_strict = true
filterwarnings = ["error"]
log_cli_level = "INFO"
testpaths = ["tests"]
markers = ["skip_xp_backend(library, *, reason=None): Skip test for a specific backend"]
markers = [
"skip_xp_backend(library, *, reason=None): Skip test for a specific backend",
"xfail_xp_backend(library, *, reason=None): Xfail test for a specific backend",
]


# Coverage
Expand Down
20 changes: 20 additions & 0 deletions src/array_api_extra/_lib/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import math
from types import ModuleType

import pytest

from ._utils._compat import (
array_namespace,
is_cupy_namespace,
Expand Down Expand Up @@ -170,3 +172,21 @@ def xp_assert_close(
np.testing.assert_allclose(
actual, desired, rtol=rtol, atol=atol, err_msg=err_msg
)


def xfail(request: pytest.FixtureRequest, reason: str) -> None:
"""
XFAIL the currently running test.

Unlike ``pytest.xfail``, allow rest of test to execute instead of immediately
halting it, so that it may result in a XPASS.
xref https://github.com/pandas-dev/pandas/issues/38902

Parameters
----------
request : pytest.FixtureRequest
``request`` argument of the test function.
reason : str
Reason for the expected failure.
"""
request.node.add_marker(pytest.mark.xfail(reason=reason))
27 changes: 16 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

from collections.abc import Callable
from contextlib import suppress
from functools import wraps
from functools import partial, wraps
from types import ModuleType
from typing import ParamSpec, TypeVar, cast

import numpy as np
import pytest

from array_api_extra._lib import Backend
from array_api_extra._lib._testing import xfail
from array_api_extra._lib._utils._compat import array_namespace
from array_api_extra._lib._utils._compat import device as get_device
from array_api_extra._lib._utils._typing import Device
Expand All @@ -32,16 +33,20 @@ def library(request: pytest.FixtureRequest) -> Backend: # numpydoc ignore=PR01,
"""
elem = cast(Backend, request.param)

for marker in request.node.iter_markers("skip_xp_backend"):
skip_library = marker.kwargs.get("library") or marker.args[0] # type: ignore[no-untyped-usage]
if not isinstance(skip_library, Backend):
msg = "argument of skip_xp_backend must be a Backend enum"
raise TypeError(msg)
if skip_library == elem:
reason = skip_library.value
with suppress(KeyError):
reason += ":" + cast(str, marker.kwargs["reason"])
pytest.skip(reason=reason)
for marker_name, skip_or_xfail in (
("skip_xp_backend", pytest.skip),
("xfail_xp_backend", partial(xfail, request)),
):
for marker in request.node.iter_markers(marker_name):
library = marker.kwargs.get("library") or marker.args[0] # type: ignore[no-untyped-usage]
if not isinstance(library, Backend):
msg = f"argument of {marker_name} must be a Backend enum"
raise TypeError(msg)
if library == elem:
reason = library.value
with suppress(KeyError):
reason += ":" + cast(str, marker.kwargs["reason"])
skip_or_xfail(reason=reason)

return elem

Expand Down
41 changes: 29 additions & 12 deletions tests/test_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from array_api_extra import at
from array_api_extra._lib import Backend
from array_api_extra._lib._at import _AtOp
from array_api_extra._lib._testing import xp_assert_equal
from array_api_extra._lib._testing import xfail, xp_assert_equal
from array_api_extra._lib._utils._compat import array_namespace, is_writeable_array
from array_api_extra._lib._utils._typing import Array, Index
from array_api_extra.testing import lazy_xp_function
Expand Down Expand Up @@ -80,10 +80,12 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
@pytest.mark.parametrize(
("kwargs", "expect_copy"),
[
({"copy": True}, True),
({"copy": False}, False),
({"copy": None}, None), # Behavior is backend-specific
({}, None), # Test that the copy parameter defaults to None
pytest.param({"copy": True}, True, id="copy=True"),
pytest.param({"copy": False}, False, id="copy=False"),
# Behavior is backend-specific
pytest.param({"copy": None}, None, id="copy=None"),
# Test that the copy parameter defaults to None
pytest.param({}, None, id="no copy kwarg"),
],
)
@pytest.mark.parametrize(
Expand All @@ -109,10 +111,10 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
True,
True,
marks=(
pytest.mark.skip_xp_backend(
pytest.mark.skip_xp_backend( # test passes when copy=False
Backend.JAX, reason="bool mask update with shaped rhs"
),
pytest.mark.skip_xp_backend(
pytest.mark.xfail_xp_backend(
Backend.DASK, reason="bool mask update with shaped rhs"
),
),
Expand Down Expand Up @@ -177,7 +179,12 @@ def test_alternate_index_syntax():
@pytest.mark.parametrize("bool_mask", [False, True])
@pytest.mark.parametrize("op", list(_AtOp))
def test_incompatible_dtype(
xp: ModuleType, library: Backend, op: _AtOp, copy: bool | None, bool_mask: bool
xp: ModuleType,
library: Backend,
request: pytest.FixtureRequest,
op: _AtOp,
copy: bool | None,
bool_mask: bool,
):
"""Test that at() replicates the backend's behaviour for
in-place operations with incompatible dtypes.
Expand Down Expand Up @@ -208,8 +215,8 @@ def test_incompatible_dtype(
z = at_op(x, idx, op, 1.1, copy=copy)

elif library is Backend.DASK:
if op in (_AtOp.MIN, _AtOp.MAX):
pytest.xfail(reason="need array-api-compat 1.11")
if op in (_AtOp.MIN, _AtOp.MAX) and bool_mask:
xfail(request, reason="need array-api-compat 1.11")
z = at_op(x, idx, op, 1.1, copy=copy)

elif library is Backend.ARRAY_API_STRICT and op is not _AtOp.SET:
Expand All @@ -234,8 +241,18 @@ def test_bool_mask_nd(xp: ModuleType):
xp_assert_equal(z, xp.asarray([[0, 2, 3], [4, 0, 0]]))


@pytest.mark.skip_xp_backend(Backend.DASK, reason="FIXME need scipy's lazywhere")
@pytest.mark.parametrize("bool_mask", [False, True])
@pytest.mark.parametrize(
"bool_mask",
[
False,
pytest.param(
True,
marks=pytest.mark.xfail_xp_backend(
Backend.DASK, reason="FIXME need scipy's lazywhere"
),
),
],
)
def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
x = xp.asarray([math.inf, 1.0, 2.0])
idx = ~xp.isinf(x) if bool_mask else slice(1, None)
Expand Down
69 changes: 45 additions & 24 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
lazy_xp_function(sinc, jax_jit=False, static_argnames="xp")


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no expand_dims")
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
class TestAtLeastND:
def test_0D(self, xp: ModuleType):
x = xp.asarray(1.0)
Expand Down Expand Up @@ -108,12 +108,12 @@ def test_device(self, xp: ModuleType, device: Device):
assert get_device(atleast_nd(x, ndim=2)) == device

def test_xp(self, xp: ModuleType):
x = xp.asarray(1)
y = atleast_nd(x, ndim=0, xp=xp)
xp_assert_equal(y, x)
x = xp.asarray(1.0)
y = atleast_nd(x, ndim=1, xp=xp)
xp_assert_equal(y, xp.ones((1,)))


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype")
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
class TestCov:
def test_basic(self, xp: ModuleType):
xp_assert_close(
Expand Down Expand Up @@ -152,16 +152,16 @@ def test_device(self, xp: ModuleType, device: Device):
x = xp.asarray([1, 2, 3], device=device)
assert get_device(cov(x)) == device

@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="explicit xp")
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
def test_xp(self, xp: ModuleType):
xp_assert_close(
cov(xp.asarray([[0.0, 2.0], [1.0, 1.0], [2.0, 0.0]]).T, xp=xp),
xp.asarray([[1.0, -1.0], [-1.0, 1.0]], dtype=xp.float64),
)


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no device")
class TestCreateDiagonal:
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
def test_1d(self, xp: ModuleType):
# from np.diag tests
vals = 100 * xp.arange(5, dtype=xp.float64)
Expand All @@ -177,6 +177,7 @@ def test_1d(self, xp: ModuleType):
xp_assert_equal(create_diagonal(vals, offset=2), b)
xp_assert_equal(create_diagonal(vals, offset=-2), c)

@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
@pytest.mark.parametrize("n", range(1, 10))
@pytest.mark.parametrize("offset", range(1, 10))
def test_create_diagonal(self, xp: ModuleType, n: int, offset: int):
Expand All @@ -196,20 +197,22 @@ def test_2d(self, xp: ModuleType):
with pytest.raises(ValueError, match="1-dimensional"):
create_diagonal(xp.asarray([[1]]))

@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
def test_device(self, xp: ModuleType, device: Device):
x = xp.asarray([1, 2, 3], device=device)
assert get_device(create_diagonal(x)) == device

@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
def test_xp(self, xp: ModuleType):
x = xp.asarray([1, 2])
y = create_diagonal(x, xp=xp)
xp_assert_equal(y, xp.asarray([[1, 0], [0, 2]]))


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no expand_dims")
class TestExpandDims:
@pytest.mark.skip_xp_backend(Backend.DASK, reason="tuple index out of range")
@pytest.mark.skip_xp_backend(Backend.TORCH, reason="tuple index out of range")
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
@pytest.mark.xfail_xp_backend(Backend.DASK, reason="tuple index out of range")
@pytest.mark.xfail_xp_backend(Backend.TORCH, reason="tuple index out of range")
def test_functionality(self, xp: ModuleType):
def _squeeze_all(b: Array) -> Array:
"""Mimics `np.squeeze(b)`. `xpx.squeeze`?"""
Expand All @@ -225,6 +228,7 @@ def _squeeze_all(b: Array) -> Array:
assert b.shape[axis] == 1
assert _squeeze_all(b).shape == s

@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
def test_axis_tuple(self, xp: ModuleType):
a = xp.empty((3, 3, 3))
assert expand_dims(a, axis=(0, 1, 2)).shape == (1, 1, 1, 3, 3, 3)
Expand Down Expand Up @@ -257,17 +261,19 @@ def test_positive_negative_repeated(self, xp: ModuleType):
with pytest.raises(ValueError, match="Duplicate dimensions"):
expand_dims(a, axis=(3, -3))

@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
def test_device(self, xp: ModuleType, device: Device):
x = xp.asarray([1, 2, 3], device=device)
assert get_device(expand_dims(x, axis=0)) == device

@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
def test_xp(self, xp: ModuleType):
x = xp.asarray([1, 2, 3])
y = expand_dims(x, axis=(0, 1, 2), xp=xp)
assert y.shape == (1, 1, 1, 3)


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype")
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
class TestIsClose:
# FIXME use lazywhere to avoid warnings on inf
@pytest.mark.filterwarnings("ignore:invalid value encountered")
Expand Down Expand Up @@ -402,7 +408,7 @@ def test_none_shape_bool(self, xp: ModuleType):
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))

@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
@pytest.mark.skip_xp_backend(Backend.TORCH, reason="Array API 2024.12 support")
@pytest.mark.xfail_xp_backend(Backend.TORCH, reason="Array API 2024.12 support")
def test_python_scalar(self, xp: ModuleType):
a = xp.asarray([0.0, 0.1], dtype=xp.float32)
xp_assert_equal(isclose(a, 0.0), xp.asarray([True, False]))
Expand All @@ -425,7 +431,7 @@ def test_xp(self, xp: ModuleType):
xp_assert_equal(isclose(a, b, xp=xp), xp.asarray([True, False]))


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no expand_dims")
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
class TestKron:
def test_basic(self, xp: ModuleType):
# Using 0-dimensional array
Expand Down Expand Up @@ -526,7 +532,7 @@ def test_xp(self, xp: ModuleType):
xp_assert_equal(nunique(a, xp=xp), xp.asarray(3))


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no arange, no device")
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange, no device")
class TestPad:
def test_simple(self, xp: ModuleType):
a = xp.arange(1, 4)
Expand Down Expand Up @@ -576,10 +582,24 @@ def test_sequence_of_tuples_width(self, xp: ModuleType):
assert padded.shape == (4, 4)


@pytest.mark.skip_xp_backend(Backend.DASK, reason="no argsort")
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no device kwarg in asarray")
assume_unique = pytest.mark.parametrize(
"assume_unique",
[
True,
pytest.param(
False,
marks=pytest.mark.xfail_xp_backend(
Backend.DASK, reason="NaN-shaped arrays"
),
),
],
)


@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in asarray()")
class TestSetDiff1D:
@pytest.mark.skip_xp_backend(
@pytest.mark.xfail_xp_backend(Backend.DASK, reason="NaN-shaped arrays")
@pytest.mark.xfail_xp_backend(
Backend.TORCH, reason="index_select not implemented for uint32"
)
def test_setdiff1d(self, xp: ModuleType):
Expand Down Expand Up @@ -608,7 +628,7 @@ def test_assume_unique(self, xp: ModuleType):
actual = setdiff1d(x1, x2, assume_unique=True)
xp_assert_equal(actual, expected)

@pytest.mark.parametrize("assume_unique", [True, False])
@assume_unique
@pytest.mark.parametrize("shape1", [(), (1,), (1, 1)])
@pytest.mark.parametrize("shape2", [(), (1,), (1, 1)])
def test_shapes(
Expand All @@ -623,8 +643,8 @@ def test_shapes(
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
xp_assert_equal(actual, xp.empty((0,)))

@assume_unique
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
@pytest.mark.parametrize("assume_unique", [True, False])
def test_python_scalar(self, xp: ModuleType, assume_unique: bool):
# Test no dtype promotion to xp.asarray(x2); use x1.dtype
x1 = xp.asarray([3, 1, 2], dtype=xp.int16)
Expand All @@ -645,21 +665,22 @@ def test_all_python_scalars(self, assume_unique: bool):
with pytest.raises(TypeError, match="Unrecognized"):
setdiff1d(0, 0, assume_unique=assume_unique)

def test_device(self, xp: ModuleType, device: Device):
@assume_unique
def test_device(self, xp: ModuleType, device: Device, assume_unique: bool):
x1 = xp.asarray([3, 8, 20], device=device)
x2 = xp.asarray([2, 3, 4], device=device)
assert get_device(setdiff1d(x1, x2)) == device
assert get_device(setdiff1d(x1, x2, assume_unique=assume_unique)) == device

@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="explicit xp")
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
def test_xp(self, xp: ModuleType):
x1 = xp.asarray([3, 8, 20])
x2 = xp.asarray([2, 3, 4])
expected = xp.asarray([8, 20])
actual = setdiff1d(x1, x2, xp=xp)
actual = setdiff1d(x1, x2, assume_unique=True, xp=xp)
xp_assert_equal(actual, expected)


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype")
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
class TestSinc:
def test_simple(self, xp: ModuleType):
xp_assert_equal(sinc(xp.asarray(0.0)), xp.asarray(1.0))
Expand Down
Loading