From b8bb10f31430f1ff8f6e1297c95cf50654fe16b5 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 12 May 2022 10:33:45 +0100 Subject: [PATCH 01/10] `func_name_to_func` -> `name_to_func` --- array_api_tests/test_special_cases.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index b8979298..43f158dd 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -1165,12 +1165,12 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: continue if param_names[0] == "x": if cases := parse_unary_docstring(stub.__doc__): - func_name_to_func = {stub.__name__: func} + name_to_func = {stub.__name__: func} if stub.__name__ in func_to_op.keys(): op_name = func_to_op[stub.__name__] op = getattr(operator, op_name) - func_name_to_func[op_name] = op - for func_name, func in func_name_to_func.items(): + name_to_func[op_name] = op + for func_name, func in name_to_func.items(): for case in cases: id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}" p = pytest.param(func_name, func, case, id=id_) @@ -1181,11 +1181,11 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: continue if param_names[0] == "x1" and param_names[1] == "x2": if cases := parse_binary_docstring(stub.__doc__): - func_name_to_func = {stub.__name__: func} + name_to_func = {stub.__name__: func} if stub.__name__ in func_to_op.keys(): op_name = func_to_op[stub.__name__] op = getattr(operator, op_name) - func_name_to_func[op_name] = op + name_to_func[op_name] = op # We collect inplaceoperator test cases seperately iop_name = "__i" + op_name[2:] iop = getattr(operator, iop_name) @@ -1193,7 +1193,7 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: id_ = f"{iop_name}({case.cond_expr}) -> {case.result_expr}" p = pytest.param(iop_name, iop, case, id=id_) iop_params.append(p) - for func_name, func in func_name_to_func.items(): + for func_name, func in name_to_func.items(): for case in cases: id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}" p = pytest.param(func_name, func, case, id=id_) From 0acee7c66c2f6a3c238cadafec85dfb5e5964a0c Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 13 May 2022 12:18:51 +0100 Subject: [PATCH 02/10] Generously capture all special cases from stubs --- array_api_tests/test_special_cases.py | 155 +++++++++++++------------- 1 file changed, 75 insertions(+), 80 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 43f158dd..d1523b44 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -526,6 +526,10 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(<{self}>)" +r_case_block = re.compile(r"\*\*Special [Cc]ases\*\*\n+((?:(.*\n)+))\n+\s*Parameters") +r_case = re.compile(r"\s+-\s*(.*)\.") + + class UnaryCond(Protocol): def __call__(self, i: float) -> bool: ... @@ -586,7 +590,7 @@ def check_result(i: float, result: float) -> bool: return check_result -def parse_unary_docstring(docstring: str) -> List[UnaryCase]: +def parse_unary_case_block(case_block: str) -> List[UnaryCase]: """ Parses a Sphinx-formatted docstring of a unary function to return a list of codified unary cases, e.g. @@ -616,7 +620,8 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]: ... an array containing the square root of each element in ``x`` ... ''' ... - >>> unary_cases = parse_unary_docstring(sqrt.__doc__) + >>> case_block = r_case_block.match(sqrt.__doc__).group(1) + >>> unary_cases = parse_unary_case_block(case_block) >>> for case in unary_cases: ... print(repr(case)) UnaryCase( NaN>) @@ -631,19 +636,10 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]: True """ - - match = r_special_cases.search(docstring) - if match is None: - return [] - lines = match.group(1).split("\n")[:-1] cases = [] - for line in lines: - if m := r_case.match(line): - case = m.group(1) - else: - warn(f"line not machine-readable: '{line}'") - continue - if m := r_unary_case.search(case): + for case_m in r_case.finditer(case_block): + case_str = case_m.group(1) + if m := r_unary_case.search(case_str): try: cond, cond_expr_template, cond_from_dtype = parse_cond(m.group(1)) _check_result, result_expr = parse_result(m.group(2)) @@ -662,11 +658,11 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]: check_result=check_result, ) cases.append(case) - elif m := r_even_round_halves_case.search(case): + elif m := r_even_round_halves_case.search(case_str): cases.append(even_round_halves_case) else: - if not r_remaining_case.search(case): - warn(f"case not machine-readable: '{case}'") + if not r_remaining_case.search(case_str): + warn(f"case not machine-readable: '{case_str}'") return cases @@ -690,12 +686,6 @@ class BinaryCase(Case): check_result: BinaryResultCheck -r_special_cases = re.compile( - r"\*\*Special [Cc]ases\*\*(?:\n.*)+" - r"For floating-point operands,\n+" - r"((?:\s*-\s*.*\n)+)" -) -r_case = re.compile(r"\s+-\s*(.*)\.\n?") r_binary_case = re.compile("If (.+), the result (.+)") r_remaining_case = re.compile("In the remaining cases.+") r_cond_sep = re.compile(r"(? BinaryCase: """ case_m = r_binary_case.match(case_str) - if case_m is None: - raise ParseError(case_str) + assert case_m is not None # sanity check cond_strs = r_cond_sep.split(case_m.group(1)) partial_conds = [] @@ -1078,7 +1067,7 @@ def cond(i1: float, i2: float) -> bool: r_redundant_case = re.compile("result.+determined by the rule already stated above") -def parse_binary_docstring(docstring: str) -> List[BinaryCase]: +def parse_binary_case_block(case_block: str) -> List[BinaryCase]: """ Parses a Sphinx-formatted docstring of a binary function to return a list of codified binary cases, e.g. @@ -1108,7 +1097,8 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: ... an array containing the results ... ''' ... - >>> binary_cases = parse_binary_docstring(logaddexp.__doc__) + >>> case_block = r_case_block.match(logaddexp.__doc__).group(1) + >>> binary_cases = parse_binary_case_block(case_block) >>> for case in binary_cases: ... print(repr(case)) BinaryCase( NaN>) @@ -1116,21 +1106,12 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: BinaryCase( +infinity>) """ - - match = r_special_cases.search(docstring) - if match is None: - return [] - lines = match.group(1).split("\n")[:-1] cases = [] - for line in lines: - if m := r_case.match(line): - case_str = m.group(1) - else: - warn(f"line not machine-readable: '{line}'") - continue + for case_m in r_case.finditer(case_block): + case_str = case_m.group(1) if r_redundant_case.search(case_str): continue - if m := r_binary_case.match(case_str): + if r_binary_case.match(case_str): try: case = parse_binary_case(case_str) cases.append(case) @@ -1142,14 +1123,19 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: return cases +category_stub_pairs = [(c, s) for c, stubs in category_to_funcs.items() for s in stubs] unary_params = [] binary_params = [] iop_params = [] func_to_op: Dict[str, str] = {v: k for k, v in dh.op_to_func.items()} -for stub in category_to_funcs["elementwise"]: +for category, stub in category_stub_pairs: if stub.__doc__ is None: warn(f"{stub.__name__}() stub has no docstring") continue + if m := r_case_block.search(stub.__doc__): + case_block = m.group(1) + else: + continue marks = [] try: func = getattr(xp, stub.__name__) @@ -1163,47 +1149,56 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: if len(sig.parameters) == 0: warn(f"{func=} has no parameters") continue - if param_names[0] == "x": - if cases := parse_unary_docstring(stub.__doc__): - name_to_func = {stub.__name__: func} - if stub.__name__ in func_to_op.keys(): - op_name = func_to_op[stub.__name__] - op = getattr(operator, op_name) - name_to_func[op_name] = op - for func_name, func in name_to_func.items(): - for case in cases: - id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}" - p = pytest.param(func_name, func, case, id=id_) - unary_params.append(p) - continue - if len(sig.parameters) == 1: - warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'") - continue - if param_names[0] == "x1" and param_names[1] == "x2": - if cases := parse_binary_docstring(stub.__doc__): - name_to_func = {stub.__name__: func} - if stub.__name__ in func_to_op.keys(): - op_name = func_to_op[stub.__name__] - op = getattr(operator, op_name) - name_to_func[op_name] = op - # We collect inplaceoperator test cases seperately - iop_name = "__i" + op_name[2:] - iop = getattr(operator, iop_name) - for case in cases: - id_ = f"{iop_name}({case.cond_expr}) -> {case.result_expr}" - p = pytest.param(iop_name, iop, case, id=id_) - iop_params.append(p) - for func_name, func in name_to_func.items(): - for case in cases: - id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}" - p = pytest.param(func_name, func, case, id=id_) - binary_params.append(p) - continue + if category == "elementwise": + if param_names[0] == "x": + if cases := parse_unary_case_block(case_block): + name_to_func = {stub.__name__: func} + if stub.__name__ in func_to_op.keys(): + op_name = func_to_op[stub.__name__] + op = getattr(operator, op_name) + name_to_func[op_name] = op + for func_name, func in name_to_func.items(): + for case in cases: + id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}" + p = pytest.param(func_name, func, case, id=id_) + unary_params.append(p) + else: + warn("TODO") + continue + if len(sig.parameters) == 1: + warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'") + continue + if param_names[0] == "x1" and param_names[1] == "x2": + if cases := parse_binary_case_block(case_block): + name_to_func = {stub.__name__: func} + if stub.__name__ in func_to_op.keys(): + op_name = func_to_op[stub.__name__] + op = getattr(operator, op_name) + name_to_func[op_name] = op + # We collect inplace operator test cases seperately + iop_name = "__i" + op_name[2:] + iop = getattr(operator, iop_name) + for case in cases: + id_ = f"{iop_name}({case.cond_expr}) -> {case.result_expr}" + p = pytest.param(iop_name, iop, case, id=id_) + iop_params.append(p) + for func_name, func in name_to_func.items(): + for case in cases: + id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}" + p = pytest.param(func_name, func, case, id=id_) + binary_params.append(p) + else: + warn("TODO") + continue + else: + warn( + f"{func=} starts with two parameters '{param_names[0]}' and " + f"'{param_names[1]}', which are not named 'x1' and 'x2'" + ) + elif category == "statistical": + pass # TODO else: - warn( - f"{func=} starts with two parameters '{param_names[0]}' and " - f"'{param_names[1]}', which are not named 'x1' and 'x2'" - ) + warn("TODO") # test_unary and test_binary naively generate arrays, i.e. arrays that might not From 8eeae4d040202272734b68c3d92c722a70ffc210 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 13 May 2022 13:41:18 +0100 Subject: [PATCH 03/10] Basic module-level doc for `test_special_cases.py` --- array_api_tests/test_special_cases.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index d1523b44..f17f69d9 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -1,3 +1,14 @@ +""" +Tests for special cases. + +The test cases for special casing are built on runtime via the parametrized +test_unary and test_binary functions. Most of this file consists of utility +classes and functions, all bought together to create the test cases (pytest +params), to finally be run through the general test logic of either test_unary +or test_binary. + +TODO: test integer arrays for relevant special cases +""" # We use __future__ for forward reference type hints - this will work for even py3.8.0 # See https://stackoverflow.com/a/33533514/5193926 from __future__ import annotations @@ -32,13 +43,6 @@ pytestmark = pytest.mark.ci -# The special case test casess are built on runtime via the parametrized -# test_unary and test_binary functions. Most of this file consists of utility -# classes and functions, all bought together to create the test cases (pytest -# params), to finally be run through the general test logic of either test_unary -# or test_binary. - - UnaryCheck = Callable[[float], bool] BinaryCheck = Callable[[float, float], bool] From a945ca4673a939e9737a242a12dbc51ee943d5e6 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 13 May 2022 14:55:40 +0100 Subject: [PATCH 04/10] Test NaN propagation special cases --- array_api_tests/pytest_helpers.py | 2 +- array_api_tests/test_special_cases.py | 115 +++++++++++++++----------- 2 files changed, 66 insertions(+), 51 deletions(-) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 5a96b27f..80e5b597 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -198,7 +198,7 @@ def assert_0d_equals( func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw ): msg = ( - f"{out_repr}={out_val}, should be {x_repr}={x_val} " + f"{out_repr}={out_val}, but should be {x_repr}={x_val} " f"[{func_name}({fmt_kw(kw)})]" ) if dh.is_float_dtype(out_val.dtype) and xp.isnan(out_val): diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index f17f69d9..cfb8659e 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -1127,12 +1127,11 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]: return cases -category_stub_pairs = [(c, s) for c, stubs in category_to_funcs.items() for s in stubs] unary_params = [] binary_params = [] iop_params = [] func_to_op: Dict[str, str] = {v: k for k, v in dh.op_to_func.items()} -for category, stub in category_stub_pairs: +for stub in category_to_funcs["elementwise"]: if stub.__doc__ is None: warn(f"{stub.__name__}() stub has no docstring") continue @@ -1153,56 +1152,51 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]: if len(sig.parameters) == 0: warn(f"{func=} has no parameters") continue - if category == "elementwise": - if param_names[0] == "x": - if cases := parse_unary_case_block(case_block): - name_to_func = {stub.__name__: func} - if stub.__name__ in func_to_op.keys(): - op_name = func_to_op[stub.__name__] - op = getattr(operator, op_name) - name_to_func[op_name] = op - for func_name, func in name_to_func.items(): - for case in cases: - id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}" - p = pytest.param(func_name, func, case, id=id_) - unary_params.append(p) - else: - warn("TODO") - continue - if len(sig.parameters) == 1: - warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'") - continue - if param_names[0] == "x1" and param_names[1] == "x2": - if cases := parse_binary_case_block(case_block): - name_to_func = {stub.__name__: func} - if stub.__name__ in func_to_op.keys(): - op_name = func_to_op[stub.__name__] - op = getattr(operator, op_name) - name_to_func[op_name] = op - # We collect inplace operator test cases seperately - iop_name = "__i" + op_name[2:] - iop = getattr(operator, iop_name) - for case in cases: - id_ = f"{iop_name}({case.cond_expr}) -> {case.result_expr}" - p = pytest.param(iop_name, iop, case, id=id_) - iop_params.append(p) - for func_name, func in name_to_func.items(): - for case in cases: - id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}" - p = pytest.param(func_name, func, case, id=id_) - binary_params.append(p) - else: - warn("TODO") - continue + if param_names[0] == "x": + if cases := parse_unary_case_block(case_block): + name_to_func = {stub.__name__: func} + if stub.__name__ in func_to_op.keys(): + op_name = func_to_op[stub.__name__] + op = getattr(operator, op_name) + name_to_func[op_name] = op + for func_name, func in name_to_func.items(): + for case in cases: + id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}" + p = pytest.param(func_name, func, case, id=id_) + unary_params.append(p) else: - warn( - f"{func=} starts with two parameters '{param_names[0]}' and " - f"'{param_names[1]}', which are not named 'x1' and 'x2'" - ) - elif category == "statistical": - pass # TODO + warn("TODO") + continue + if len(sig.parameters) == 1: + warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'") + continue + if param_names[0] == "x1" and param_names[1] == "x2": + if cases := parse_binary_case_block(case_block): + name_to_func = {stub.__name__: func} + if stub.__name__ in func_to_op.keys(): + op_name = func_to_op[stub.__name__] + op = getattr(operator, op_name) + name_to_func[op_name] = op + # We collect inplace operator test cases seperately + iop_name = "__i" + op_name[2:] + iop = getattr(operator, iop_name) + for case in cases: + id_ = f"{iop_name}({case.cond_expr}) -> {case.result_expr}" + p = pytest.param(iop_name, iop, case, id=id_) + iop_params.append(p) + for func_name, func in name_to_func.items(): + for case in cases: + id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}" + p = pytest.param(func_name, func, case, id=id_) + binary_params.append(p) + else: + warn("TODO") + continue else: - warn("TODO") + warn( + f"{func=} starts with two parameters '{param_names[0]}' and " + f"'{param_names[1]}', which are not named 'x1' and 'x2'" + ) # test_unary and test_binary naively generate arrays, i.e. arrays that might not @@ -1342,3 +1336,24 @@ def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data): ) break assume(good_example) + + +@pytest.mark.parametrize( + "func_name", [f.__name__ for f in category_to_funcs["statistical"]] +) +@given( + x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)), + data=st.data(), +) +def test_nan_propagation(func_name, x, data): + func = getattr(xp, func_name) + set_idx = data.draw( + xps.indices(x.shape, max_dims=0, allow_ellipsis=False), label="set idx" + ) + x[set_idx] = float("nan") + note(f"{x=}") + + out = func(x) + + ph.assert_shape(func_name, out.shape, ()) # sanity check + assert xp.isnan(out), f"{out=!r}, but should be NaN" From 6f05d364e1fb0f48b0143ac1db7d7158da3d02ce Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 13 May 2022 15:14:57 +0100 Subject: [PATCH 05/10] Test special cased empty array behaviour --- array_api_tests/test_special_cases.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index cfb8659e..a4b7c30b 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -1338,6 +1338,28 @@ def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data): assume(good_example) +@pytest.mark.parametrize( + "func_name, expected", + [ + ("mean", float("nan")), + ("prod", 1), + ("std", float("nan")), + ("sum", 0), + ("var", float("nan")), + ], + ids=["mean", "prod", "std", "sum", "var"], +) +def test_empty_arrays(func_name, expected): # TODO: parse docstrings to get expected + func = getattr(xp, func_name) + out = func(xp.asarray([], dtype=dh.default_float)) + ph.assert_shape(func_name, out.shape, ()) # sanity check + msg = f"{out=!r}, but should be {expected}" + if math.isnan(expected): + assert xp.isnan(out), msg + else: + assert out == expected, msg + + @pytest.mark.parametrize( "func_name", [f.__name__ for f in category_to_funcs["statistical"]] ) From c9f2c40a2c2f887e7a99ff608e02cc90fba703c1 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 13 May 2022 17:49:49 +0100 Subject: [PATCH 06/10] Update wording on docs and warnings --- array_api_tests/test_special_cases.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index a4b7c30b..5d1abd31 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -1,11 +1,10 @@ """ Tests for special cases. -The test cases for special casing are built on runtime via the parametrized -test_unary and test_binary functions. Most of this file consists of utility +Most test cases for special casing are built on runtime via the parametrized +tests test_unary/test_binary/test_iop. Most of this file consists of utility classes and functions, all bought together to create the test cases (pytest -params), to finally be run through the general test logic of either test_unary -or test_binary. +params), to finally be run through generalised test logic. TODO: test integer arrays for relevant special cases """ @@ -1165,7 +1164,7 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]: p = pytest.param(func_name, func, case, id=id_) unary_params.append(p) else: - warn("TODO") + warn(f"Special cases found for {stub.__name__} but none were parsed") continue if len(sig.parameters) == 1: warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'") @@ -1190,7 +1189,7 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]: p = pytest.param(func_name, func, case, id=id_) binary_params.append(p) else: - warn("TODO") + warn(f"Special cases found for {stub.__name__} but none were parsed") continue else: warn( @@ -1199,7 +1198,7 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]: ) -# test_unary and test_binary naively generate arrays, i.e. arrays that might not +# test_{unary/binary/iop} naively generate arrays, i.e. arrays that might not # meet the condition that is being test. We then forcibly make the array meet # the condition by picking a random index to insert an acceptable element. # From ba6b87cdfde3f0b1d0066bf6259edc53316dc1c7 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 13 May 2022 18:15:17 +0100 Subject: [PATCH 07/10] Remove unused `parse_inline_code` util --- array_api_tests/test_special_cases.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 5d1abd31..53a4bc64 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -173,24 +173,6 @@ def parse_value(value_str: str) -> float: r_approx_value = re.compile( rf"an implementation-dependent approximation to {r_code.pattern}" ) - - -def parse_inline_code(inline_code: str) -> float: - """ - Parses a Sphinx code string to return a float, e.g. - - >>> parse_value('``0``') - 0. - >>> parse_value('``NaN``') - float('nan') - - """ - if m := r_code.match(inline_code): - return parse_value(m.group(1)) - else: - raise ParseError(inline_code) - - r_not = re.compile("not (.+)") r_equal_to = re.compile(f"equal to {r_code.pattern}") r_array_element = re.compile(r"``([+-]?)x([12])_i``") From 2e5322b5b73efdc20bfe3e12e25ce2b972b8def7 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 13 May 2022 18:51:16 +0100 Subject: [PATCH 08/10] Parse "already integer-valued" special cases --- array_api_tests/test_special_cases.py | 56 ++++++++++++++++----------- 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 53a4bc64..2c14fdfa 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -535,12 +535,34 @@ class UnaryCase(Case): r_unary_case = re.compile("If ``x_i`` is (.+), the result is (.+)") +r_already_int_case = re.compile( + "If ``x_i`` is already integer-valued, the result is ``x_i``" +) r_even_round_halves_case = re.compile( "If two integers are equally close to ``x_i``, " "the result is the even integer closest to ``x_i``" ) +def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: + """ + Returns a strategy that generates float-casted integers within the bounds of dtype. + """ + for k in kw.keys(): + # sanity check + assert k in ["min_value", "max_value", "exclude_min", "exclude_max"] + m, M = dh.dtype_ranges[dtype] + if "min_value" in kw.keys(): + m = kw["min_value"] + if "exclude_min" in kw.keys(): + m += 1 + if "max_value" in kw.keys(): + M = kw["max_value"] + if "exclude_max" in kw.keys(): + M -= 1 + return st.integers(math.ceil(m), math.floor(M)).map(float) + + def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]: """ Returns a strategy that generates floats that end with .5 and are within the @@ -557,6 +579,13 @@ def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]: ) +already_int_case = UnaryCase( + cond_expr="x_i.is_integer()", + cond=lambda i: i.is_integer(), + cond_from_dtype=integers_from_dtype, + result_expr="x_i", + check_result=lambda i, result: i == result, +) even_round_halves_case = UnaryCase( cond_expr="modf(i)[0] == 0.5", cond=lambda i: math.modf(i)[0] == 0.5, @@ -624,7 +653,11 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]: cases = [] for case_m in r_case.finditer(case_block): case_str = case_m.group(1) - if m := r_unary_case.search(case_str): + if m := r_already_int_case.search(case_str): + cases.append(already_int_case) + elif m := r_even_round_halves_case.search(case_str): + cases.append(even_round_halves_case) + elif m := r_unary_case.search(case_str): try: cond, cond_expr_template, cond_from_dtype = parse_cond(m.group(1)) _check_result, result_expr = parse_result(m.group(2)) @@ -643,8 +676,6 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]: check_result=check_result, ) cases.append(case) - elif m := r_even_round_halves_case.search(case_str): - cases.append(even_round_halves_case) else: if not r_remaining_case.search(case_str): warn(f"case not machine-readable: '{case_str}'") @@ -818,25 +849,6 @@ def check_result(i1: float, i2: float, result: float) -> bool: return check_result -def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: - """ - Returns a strategy that generates float-casted integers within the bounds of dtype. - """ - for k in kw.keys(): - # sanity check - assert k in ["min_value", "max_value", "exclude_min", "exclude_max"] - m, M = dh.dtype_ranges[dtype] - if "min_value" in kw.keys(): - m = kw["min_value"] - if "exclude_min" in kw.keys(): - m += 1 - if "max_value" in kw.keys(): - M = kw["max_value"] - if "exclude_max" in kw.keys(): - M -= 1 - return st.integers(math.ceil(m), math.floor(M)).map(float) - - def parse_binary_case(case_str: str) -> BinaryCase: """ Parses a Sphinx-formatted binary case string to return codified binary cases, e.g. From 9ef6d69c3cd9ed2899c56bce6435b132be9c983b Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 13 May 2022 18:52:45 +0100 Subject: [PATCH 09/10] Bump `array-api` submodule --- array-api | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array-api b/array-api index 02fa9237..c5808f2b 160000 --- a/array-api +++ b/array-api @@ -1 +1 @@ -Subproject commit 02fa9237eab3258120778baec12cd38cfd309ee3 +Subproject commit c5808f2b173ea52d813c450bec7b1beaf2973299 From e933be9184d8fe979b90b251082f7e2e9cf2a78e Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 13 May 2022 19:08:30 +0100 Subject: [PATCH 10/10] Fix internal parsing docstring examples --- array_api_tests/test_special_cases.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 2c14fdfa..e6580863 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -634,7 +634,7 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]: ... an array containing the square root of each element in ``x`` ... ''' ... - >>> case_block = r_case_block.match(sqrt.__doc__).group(1) + >>> case_block = r_case_block.search(sqrt.__doc__).group(1) >>> unary_cases = parse_unary_case_block(case_block) >>> for case in unary_cases: ... print(repr(case)) @@ -1094,7 +1094,7 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]: ... an array containing the results ... ''' ... - >>> case_block = r_case_block.match(logaddexp.__doc__).group(1) + >>> case_block = r_case_block.search(logaddexp.__doc__).group(1) >>> binary_cases = parse_binary_case_block(case_block) >>> for case in binary_cases: ... print(repr(case))