Skip to content

Commit 9d1f4da

Browse files
committed
Extend note on refimpl utils
1 parent dfda4f5 commit 9d1f4da

File tree

1 file changed

+38
-22
lines changed

1 file changed

+38
-22
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

+38-22
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,3 @@
1-
"""
2-
Tests for elementwise functions
3-
4-
https://data-apis.github.io/array-api/latest/API_specification/elementwise_functions.html
5-
6-
This tests behavior that is explicitly mentioned in the spec. Note that the
7-
spec does not make any accuracy requirements for functions, so this does not
8-
test that. Tests for the special cases are generated and tested separately in
9-
special_cases/
10-
"""
11-
121
import math
132
import operator
143
from enum import Enum, auto
@@ -41,13 +30,6 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
4130
return xps.boolean_dtypes() | all_integer_dtypes()
4231

4332

44-
def isclose(a: float, b: float, rel_tol: float = 0.25, abs_tol: float = 1) -> bool:
45-
"""Wraps math.isclose with more generous defaults."""
46-
if not (math.isfinite(a) and math.isfinite(b)):
47-
raise ValueError(f"{a=} and {b=}, but input must be finite")
48-
return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
49-
50-
5133
def mock_int_dtype(n: int, dtype: DataType) -> int:
5234
"""Returns equivalent of `n` that mocks `dtype` behaviour."""
5335
nbits = dh.dtype_nbits[dtype]
@@ -60,6 +42,40 @@ def mock_int_dtype(n: int, dtype: DataType) -> int:
6042
return n
6143

6244

45+
# This module tests elementwise functions/operators against a reference
46+
# implementation. We iterate through the input array(s) and resulting array,
47+
# casting the indexed arrays to Python scalars and calculating the expected
48+
# output with `refimpl` function.
49+
#
50+
# This is finicky to refactor, but possible and ultimately worthwhile - hence
51+
# why these *_assert_again_refimpl() utilities exist.
52+
#
53+
# Values which are special-cased are generated and passed, but are filtered by
54+
# the `filter_` callable before they can be asserted against `refimpl`. We
55+
# automatically generate tests for special cases in the special_cases/ dir. We
56+
# still pass them here so as to ensure their presence doesn't affect the outputs
57+
# respective to non-special-cased elements.
58+
#
59+
# By default, results are casted to scalars the same way that the inputs are.
60+
# You can specify a cast via `res_stype, i.e. when a function accepts numerical
61+
# inputs but returns boolean arrays.
62+
#
63+
# By default, floating-point functions/methods are loosely asserted against. Use
64+
# `strict_check=True` when they should be strictly asserted against, i.e.
65+
# when a function should return intergrals.
66+
67+
68+
def isclose(a: float, b: float, rel_tol: float = 0.25, abs_tol: float = 1) -> bool:
69+
"""Wraps math.isclose with very generous defaults.
70+
71+
This is useful for many floating-point operations where the spec does not
72+
make accuracy requirements.
73+
"""
74+
if not (math.isfinite(a) and math.isfinite(b)):
75+
raise ValueError(f"{a=} and {b=}, but input must be finite")
76+
return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
77+
78+
6379
def default_filter(s: Scalar) -> bool:
6480
"""Returns False when s is a non-finite or a signed zero.
6581
@@ -168,14 +184,14 @@ def binary_assert_against_refimpl(
168184
# elementwise methods. We do this by parametrizing a generalised test method
169185
# with every relevant method and operator.
170186
#
171-
# Notable arguments in the parameter:
187+
# Notable arguments in the parameter's context object:
172188
# - The function object, which for operator test cases is a wrapper that allows
173189
# test logic to be generalised.
174190
# - The argument strategies, which can be used to draw arguments for the test
175191
# case. They may require additional filtering for certain test cases.
176-
# - right_is_scalar (binary parameters), which denotes if the right argument is
177-
# a scalar in a test case. This can be used to appropiately adjust draw
178-
# filtering and test logic.
192+
# - right_is_scalar (binary parameters only), which denotes if the right
193+
# argument is a scalar in a test case. This can be used to appropiately adjust
194+
# draw filtering and test logic.
179195

180196

181197
func_to_op = {v: k for k, v in dh.op_to_func.items()}

0 commit comments

Comments
 (0)