Skip to content

Commit 6266067

Browse files
committed
ENH: new functions isclose and allclose
1 parent 3754e7c commit 6266067

File tree

8 files changed

+332
-4
lines changed

8 files changed

+332
-4
lines changed

docs/api-reference.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
:nosignatures:
77
:toctree: generated
88
9+
allclose
910
at
1011
atleast_nd
1112
cov
1213
create_diagonal
1314
expand_dims
15+
isclose
1416
kron
1517
nunique
1618
pad

pixi.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ checks = [
293293
"all", # report on all checks, except the below
294294
"EX01", # most docstrings do not need an example
295295
"SA01", # data-apis/array-api-extra#87
296+
"SA04", # Missing description for See Also cross-reference
296297
"ES01", # most docstrings do not need an extended summary
297298
]
298299
exclude = [ # don't report on objects that match any of these regex

src/array_api_extra/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._delegation import pad
3+
from ._delegation import allclose, isclose, pad
44
from ._lib._at import at
55
from ._lib._funcs import (
66
atleast_nd,
@@ -18,11 +18,13 @@
1818
# pylint: disable=duplicate-code
1919
__all__ = [
2020
"__version__",
21+
"allclose",
2122
"at",
2223
"atleast_nd",
2324
"cov",
2425
"create_diagonal",
2526
"expand_dims",
27+
"isclose",
2628
"kron",
2729
"nunique",
2830
"pad",

src/array_api_extra/_delegation.py

Lines changed: 139 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ._lib._utils._compat import array_namespace
99
from ._lib._utils._typing import Array
1010

11-
__all__ = ["pad"]
11+
__all__ = ["allclose", "isclose", "pad"]
1212

1313

1414
def _delegate(xp: ModuleType, *backends: Backend) -> bool:
@@ -30,6 +30,144 @@ def _delegate(xp: ModuleType, *backends: Backend) -> bool:
3030
return any(backend.is_namespace(xp) for backend in backends)
3131

3232

33+
def allclose(
34+
a: Array,
35+
b: Array,
36+
*,
37+
rtol: float = 1e-05,
38+
atol: float = 1e-08,
39+
equal_nan: bool = False,
40+
xp: ModuleType | None = None,
41+
) -> Array:
42+
"""
43+
Return True if two arrays are element-wise equal within a tolerance.
44+
45+
This is a simple convenience reduction around `isclose`.
46+
47+
Parameters
48+
----------
49+
a, b : Array
50+
Input arrays to compare.
51+
rtol : array_like, optional
52+
The relative tolerance parameter.
53+
atol : array_like, optional
54+
The absolute tolerance parameter.
55+
equal_nan : bool, optional
56+
Whether to compare NaN's as equal. If True, NaN's in `a` will be considered
57+
equal to NaN's in `b` in the output array.
58+
xp : array_namespace, optional
59+
The standard-compatible namespace for `a` and `b`. Default: infer.
60+
61+
Returns
62+
-------
63+
Array
64+
A 0-dimensional boolean array, containing `True` if all `a` is elementwise close
65+
to `b` and `False` otherwise.
66+
67+
See Also
68+
--------
69+
isclose
70+
math.isclose
71+
72+
Notes
73+
-----
74+
If `xp` is a lazy backend (e.g. Dask, JAX), you may not be able to test the result
75+
contents with ``bool(allclose(a, b))`` or ``if allclose(a, b): ...``.
76+
"""
77+
xp = array_namespace(a, b) if xp is None else xp
78+
return xp.all(isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp))
79+
80+
81+
def isclose(
82+
a: Array,
83+
b: Array,
84+
*,
85+
rtol: float = 1e-05,
86+
atol: float = 1e-08,
87+
equal_nan: bool = False,
88+
xp: ModuleType | None = None,
89+
) -> Array:
90+
"""
91+
Return a boolean array where two arrays are element-wise equal within a tolerance.
92+
93+
The tolerance values are positive, typically very small numbers. The relative
94+
difference `(rtol * abs(b))` and the absolute difference atol are added together to
95+
compare against the absolute difference between a and b.
96+
97+
NaNs are treated as equal if they are in the same place and if equal_nan=True. Infs
98+
are treated as equal if they are in the same place and of the same sign in both
99+
arrays.
100+
101+
Parameters
102+
----------
103+
a, b : Array
104+
Input arrays to compare.
105+
rtol : array_like, optional
106+
The relative tolerance parameter (see Notes).
107+
atol : array_like, optional
108+
The absolute tolerance parameter (see Notes).
109+
equal_nan : bool, optional
110+
Whether to compare NaN's as equal. If True, NaN's in `a` will be considered
111+
equal to NaN's in `b` in the output array.
112+
xp : array_namespace, optional
113+
The standard-compatible namespace for `a` and `b`. Default: infer.
114+
115+
Returns
116+
-------
117+
Array
118+
A boolean array of shape broadcasted from `a` and `b`, containing `True` where
119+
``a`` is close to ``b``, and `False` otherwise.
120+
121+
Warnings
122+
--------
123+
The default atol is not appropriate for comparing numbers with magnitudes much
124+
smaller than one ) (see notes).
125+
126+
See Also
127+
--------
128+
allclose
129+
math.isclose
130+
131+
Notes
132+
-----
133+
For finite values, `isclose` uses the following equation to test whether two
134+
floating point values are equivalent::
135+
136+
absolute(a - b) <= (atol + rtol * absolute(b))
137+
138+
Unlike the built-in `math.isclose`, the above equation is not symmetric in a and b,
139+
so that `isclose(a, b)` might be different from `isclose(b, a)` in some rare
140+
cases.
141+
142+
The default value of `atol` is not appropriate when the reference value `b` has
143+
magnitude smaller than one. For example, it is unlikely that ``a = 1e-9`` and
144+
``b = 2e-9`` should be considered "close", yet ``isclose(1e-9, 2e-9)`` is `True`
145+
with default settings. Be sure to select atol for the use case at hand, especially
146+
for defining the threshold below which a non-zero value in `a` will be considered
147+
"close" to a very small or zero value in `b`.
148+
149+
The comparison of `a` and `b` uses standard broadcasting, which means that `a` and
150+
`b` need not have the same shape in order for `isclose(a, b)` to evaluate to
151+
`True`.
152+
153+
`isclose` is not defined for non-numeric data types. `bool` is considered a numeric
154+
data-type for this purpose.
155+
"""
156+
xp = array_namespace(a, b) if xp is None else xp
157+
158+
if _delegate(
159+
xp,
160+
Backend.NUMPY,
161+
Backend.CUPY,
162+
Backend.DASK,
163+
Backend.JAX,
164+
Backend.TORCH,
165+
):
166+
return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
167+
168+
return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)
169+
170+
33171
def pad(
34172
x: Array,
35173
pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],

src/array_api_extra/_lib/_funcs.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,41 @@ def expand_dims(
305305
return a
306306

307307

308+
def isclose(
309+
a: Array,
310+
b: Array,
311+
*,
312+
rtol: float = 1e-05,
313+
atol: float = 1e-08,
314+
equal_nan: bool = False,
315+
xp: ModuleType | None = None,
316+
) -> Array: # numpydoc ignore=PR01,RT01
317+
"""See docstring in array_api_extra._delegation."""
318+
xp = array_namespace(a, b) if xp is None else xp
319+
320+
a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
321+
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
322+
if a_inexact or b_inexact:
323+
# FIXME: use scipy's lazywhere to suppress warnings on inf
324+
out = xp.abs(a - b) <= (atol + rtol * xp.abs(b))
325+
out = xp.where(xp.isinf(a) & xp.isinf(b), xp.sign(a) == xp.sign(b), out)
326+
if equal_nan:
327+
out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out)
328+
return out
329+
330+
if xp.isdtype(a.dtype, "bool") or xp.isdtype(b.dtype, "bool"):
331+
if atol >= 1 or rtol >= 1:
332+
return xp.ones_like(a == b)
333+
return a == b
334+
335+
# integer types
336+
atol = int(atol)
337+
if rtol == 0:
338+
return xp.abs(a - b) <= atol
339+
nrtol = int(1.0 / rtol)
340+
return xp.abs(a - b) <= (atol + xp.abs(b) // nrtol)
341+
342+
308343
def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
309344
"""
310345
Kronecker product of two arrays.

src/array_api_extra/_lib/_testing.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
6161
The expected array (typically hardcoded).
6262
err_msg : str, optional
6363
Error message to display on failure.
64+
65+
See Also
66+
--------
67+
xp_assert_close
68+
np.testing.assert_array_equal
6469
"""
6570
xp = _check_ns_shape_dtype(actual, desired)
6671

@@ -112,6 +117,16 @@ def xp_assert_close(
112117
Absolute tolerance. Default: 0.
113118
err_msg : str, optional
114119
Error message to display on failure.
120+
121+
See Also
122+
--------
123+
xp_assert_equal
124+
allclose
125+
numpy.testing.assert_allclose
126+
127+
Notes
128+
-----
129+
The default `atol` and `rtol` differ from `xpx.allclose`.
115130
"""
116131
xp = _check_ns_shape_dtype(actual, desired)
117132

0 commit comments

Comments
 (0)