diff --git a/array-api b/array-api index 02fa9237..c5808f2b 160000 --- a/array-api +++ b/array-api @@ -1 +1 @@ -Subproject commit 02fa9237eab3258120778baec12cd38cfd309ee3 +Subproject commit c5808f2b173ea52d813c450bec7b1beaf2973299 diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 5a96b27f..80e5b597 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -198,7 +198,7 @@ def assert_0d_equals( func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw ): msg = ( -        f"{out_repr}={out_val}, should be {x_repr}={x_val} " +        f"{out_repr}={out_val}, but should be {x_repr}={x_val} " f"[{func_name}({fmt_kw(kw)})]" ) if dh.is_float_dtype(out_val.dtype) and xp.isnan(out_val): diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index b8979298..e6580863 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -1,3 +1,13 @@ +""" +Tests for special cases. + +Most test cases for special casing are built on runtime via the parametrized +tests test_unary/test_binary/test_iop. Most of this file consists of utility +classes and functions, all brought together to create the test cases (pytest +params), to finally be run through generalised test logic. + +TODO: test integer arrays for relevant special cases +""" # We use __future__ for forward reference type hints - this will work for even py3.8.0 # See https://stackoverflow.com/a/33533514/5193926 from __future__ import annotations @@ -32,13 +42,6 @@ pytestmark = pytest.mark.ci -# The special case test casess are built on runtime via the parametrized -# test_unary and test_binary functions. Most of this file consists of utility -# classes and functions, all bought together to create the test cases (pytest -# params), to finally be run through the general test logic of either test_unary -# or test_binary. 
- - UnaryCheck = Callable[[float], bool] BinaryCheck = Callable[[float, float], bool] @@ -170,24 +173,6 @@ def parse_value(value_str: str) -> float: r_approx_value = re.compile( rf"an implementation-dependent approximation to {r_code.pattern}" ) - - -def parse_inline_code(inline_code: str) -> float: - """ - Parses a Sphinx code string to return a float, e.g. - - >>> parse_value('``0``') - 0. - >>> parse_value('``NaN``') - float('nan') - - """ - if m := r_code.match(inline_code): - return parse_value(m.group(1)) - else: - raise ParseError(inline_code) - - r_not = re.compile("not (.+)") r_equal_to = re.compile(f"equal to {r_code.pattern}") r_array_element = re.compile(r"``([+-]?)x([12])_i``") @@ -526,6 +511,10 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(<{self}>)" +r_case_block = re.compile(r"\*\*Special [Cc]ases\*\*\n+((?:(.*\n)+))\n+\s*Parameters") +r_case = re.compile(r"\s+-\s*(.*)\.") + + class UnaryCond(Protocol): def __call__(self, i: float) -> bool: ... @@ -546,12 +535,34 @@ class UnaryCase(Case): r_unary_case = re.compile("If ``x_i`` is (.+), the result is (.+)") +r_already_int_case = re.compile( + "If ``x_i`` is already integer-valued, the result is ``x_i``" +) r_even_round_halves_case = re.compile( "If two integers are equally close to ``x_i``, " "the result is the even integer closest to ``x_i``" ) +def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: + """ + Returns a strategy that generates float-casted integers within the bounds of dtype. 
+ """ + for k in kw.keys(): + # sanity check + assert k in ["min_value", "max_value", "exclude_min", "exclude_max"] + m, M = dh.dtype_ranges[dtype] + if "min_value" in kw.keys(): + m = kw["min_value"] + if "exclude_min" in kw.keys(): + m += 1 + if "max_value" in kw.keys(): + M = kw["max_value"] + if "exclude_max" in kw.keys(): + M -= 1 + return st.integers(math.ceil(m), math.floor(M)).map(float) + + def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]: """ Returns a strategy that generates floats that end with .5 and are within the @@ -568,6 +579,13 @@ def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]: ) +already_int_case = UnaryCase( + cond_expr="x_i.is_integer()", + cond=lambda i: i.is_integer(), + cond_from_dtype=integers_from_dtype, + result_expr="x_i", + check_result=lambda i, result: i == result, +) even_round_halves_case = UnaryCase( cond_expr="modf(i)[0] == 0.5", cond=lambda i: math.modf(i)[0] == 0.5, @@ -586,7 +604,7 @@ def check_result(i: float, result: float) -> bool: return check_result -def parse_unary_docstring(docstring: str) -> List[UnaryCase]: +def parse_unary_case_block(case_block: str) -> List[UnaryCase]: """ Parses a Sphinx-formatted docstring of a unary function to return a list of codified unary cases, e.g. @@ -616,7 +634,8 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]: ... an array containing the square root of each element in ``x`` ... ''' ... - >>> unary_cases = parse_unary_docstring(sqrt.__doc__) + >>> case_block = r_case_block.search(sqrt.__doc__).group(1) + >>> unary_cases = parse_unary_case_block(case_block) >>> for case in unary_cases: ... 
print(repr(case)) UnaryCase( NaN>) @@ -631,19 +650,14 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]: True """ - - match = r_special_cases.search(docstring) - if match is None: - return [] - lines = match.group(1).split("\n")[:-1] cases = [] - for line in lines: - if m := r_case.match(line): - case = m.group(1) - else: - warn(f"line not machine-readable: '{line}'") - continue - if m := r_unary_case.search(case): + for case_m in r_case.finditer(case_block): + case_str = case_m.group(1) + if m := r_already_int_case.search(case_str): + cases.append(already_int_case) + elif m := r_even_round_halves_case.search(case_str): + cases.append(even_round_halves_case) + elif m := r_unary_case.search(case_str): try: cond, cond_expr_template, cond_from_dtype = parse_cond(m.group(1)) _check_result, result_expr = parse_result(m.group(2)) @@ -662,11 +676,9 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]: check_result=check_result, ) cases.append(case) - elif m := r_even_round_halves_case.search(case): - cases.append(even_round_halves_case) else: - if not r_remaining_case.search(case): - warn(f"case not machine-readable: '{case}'") + if not r_remaining_case.search(case_str): + warn(f"case not machine-readable: '{case_str}'") return cases @@ -690,12 +702,6 @@ class BinaryCase(Case): check_result: BinaryResultCheck -r_special_cases = re.compile( - r"\*\*Special [Cc]ases\*\*(?:\n.*)+" - r"For floating-point operands,\n+" - r"((?:\s*-\s*.*\n)+)" -) -r_case = re.compile(r"\s+-\s*(.*)\.\n?") r_binary_case = re.compile("If (.+), the result (.+)") r_remaining_case = re.compile("In the remaining cases.+") r_cond_sep = re.compile(r"(? bool: return check_result -def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: - """ - Returns a strategy that generates float-casted integers within the bounds of dtype. 
- """ - for k in kw.keys(): - # sanity check - assert k in ["min_value", "max_value", "exclude_min", "exclude_max"] - m, M = dh.dtype_ranges[dtype] - if "min_value" in kw.keys(): - m = kw["min_value"] - if "exclude_min" in kw.keys(): - m += 1 - if "max_value" in kw.keys(): - M = kw["max_value"] - if "exclude_max" in kw.keys(): - M -= 1 - return st.integers(math.ceil(m), math.floor(M)).map(float) - - def parse_binary_case(case_str: str) -> BinaryCase: """ Parses a Sphinx-formatted binary case string to return codified binary cases, e.g. @@ -880,8 +867,7 @@ def parse_binary_case(case_str: str) -> BinaryCase: """ case_m = r_binary_case.match(case_str) - if case_m is None: - raise ParseError(case_str) + assert case_m is not None # sanity check cond_strs = r_cond_sep.split(case_m.group(1)) partial_conds = [] @@ -1078,7 +1064,7 @@ def cond(i1: float, i2: float) -> bool: r_redundant_case = re.compile("result.+determined by the rule already stated above") -def parse_binary_docstring(docstring: str) -> List[BinaryCase]: +def parse_binary_case_block(case_block: str) -> List[BinaryCase]: """ Parses a Sphinx-formatted docstring of a binary function to return a list of codified binary cases, e.g. @@ -1108,7 +1094,8 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: ... an array containing the results ... ''' ... - >>> binary_cases = parse_binary_docstring(logaddexp.__doc__) + >>> case_block = r_case_block.search(logaddexp.__doc__).group(1) + >>> binary_cases = parse_binary_case_block(case_block) >>> for case in binary_cases: ... 
print(repr(case)) BinaryCase( NaN>) @@ -1116,21 +1103,12 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: BinaryCase( +infinity>) """ - - match = r_special_cases.search(docstring) - if match is None: - return [] - lines = match.group(1).split("\n")[:-1] cases = [] - for line in lines: - if m := r_case.match(line): - case_str = m.group(1) - else: - warn(f"line not machine-readable: '{line}'") - continue + for case_m in r_case.finditer(case_block): + case_str = case_m.group(1) if r_redundant_case.search(case_str): continue - if m := r_binary_case.match(case_str): + if r_binary_case.match(case_str): try: case = parse_binary_case(case_str) cases.append(case) @@ -1150,6 +1128,10 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: if stub.__doc__ is None: warn(f"{stub.__name__}() stub has no docstring") continue + if m := r_case_block.search(stub.__doc__): + case_block = m.group(1) + else: + continue marks = [] try: func = getattr(xp, stub.__name__) @@ -1164,40 +1146,44 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: warn(f"{func=} has no parameters") continue if param_names[0] == "x": - if cases := parse_unary_docstring(stub.__doc__): - func_name_to_func = {stub.__name__: func} + if cases := parse_unary_case_block(case_block): + name_to_func = {stub.__name__: func} if stub.__name__ in func_to_op.keys(): op_name = func_to_op[stub.__name__] op = getattr(operator, op_name) - func_name_to_func[op_name] = op - for func_name, func in func_name_to_func.items(): + name_to_func[op_name] = op + for func_name, func in name_to_func.items(): for case in cases: id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}" p = pytest.param(func_name, func, case, id=id_) unary_params.append(p) + else: + warn(f"Special cases found for {stub.__name__} but none were parsed") continue if len(sig.parameters) == 1: warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'") continue if param_names[0] == "x1" and 
param_names[1] == "x2": -        if cases := parse_binary_docstring(stub.__doc__): -            func_name_to_func = {stub.__name__: func} +        if cases := parse_binary_case_block(case_block): +            name_to_func = {stub.__name__: func} if stub.__name__ in func_to_op.keys(): op_name = func_to_op[stub.__name__] op = getattr(operator, op_name) -                func_name_to_func[op_name] = op -            # We collect inplaceoperator test cases seperately +                name_to_func[op_name] = op +            # We collect inplace operator test cases separately iop_name = "__i" + op_name[2:] iop = getattr(operator, iop_name) for case in cases: id_ = f"{iop_name}({case.cond_expr}) -> {case.result_expr}" p = pytest.param(iop_name, iop, case, id=id_) iop_params.append(p) -            for func_name, func in func_name_to_func.items(): +            for func_name, func in name_to_func.items(): for case in cases: id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}" p = pytest.param(func_name, func, case, id=id_) binary_params.append(p) +            else: +                warn(f"Special cases found for {stub.__name__} but none were parsed") continue else: warn( @@ -1206,7 +1192,7 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: -# test_unary and test_binary naively generate arrays, i.e. arrays that might not +# test_{unary/binary/iop} naively generate arrays, i.e. arrays that might not # meet the condition that is being test. We then forcibly make the array meet # the condition by picking a random index to insert an acceptable element. 
# @@ -1343,3 +1329,46 @@ def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data): ) break assume(good_example) + + +@pytest.mark.parametrize( + "func_name, expected", + [ + ("mean", float("nan")), + ("prod", 1), + ("std", float("nan")), + ("sum", 0), + ("var", float("nan")), + ], + ids=["mean", "prod", "std", "sum", "var"], +) +def test_empty_arrays(func_name, expected): # TODO: parse docstrings to get expected + func = getattr(xp, func_name) + out = func(xp.asarray([], dtype=dh.default_float)) + ph.assert_shape(func_name, out.shape, ()) # sanity check + msg = f"{out=!r}, but should be {expected}" + if math.isnan(expected): + assert xp.isnan(out), msg + else: + assert out == expected, msg + + +@pytest.mark.parametrize( + "func_name", [f.__name__ for f in category_to_funcs["statistical"]] +) +@given( + x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)), + data=st.data(), +) +def test_nan_propagation(func_name, x, data): + func = getattr(xp, func_name) + set_idx = data.draw( + xps.indices(x.shape, max_dims=0, allow_ellipsis=False), label="set idx" + ) + x[set_idx] = float("nan") + note(f"{x=}") + + out = func(x) + + ph.assert_shape(func_name, out.shape, ()) # sanity check + assert xp.isnan(out), f"{out=!r}, but should be NaN"