Skip to content

Commit 2e5322b

Browse files
committed
Parse "already integer-valued" special cases
1 parent ba6b87c commit 2e5322b

File tree

1 file changed

+34
-22
lines changed

1 file changed

+34
-22
lines changed

array_api_tests/test_special_cases.py

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -535,12 +535,34 @@ class UnaryCase(Case):
535535

536536

537537
r_unary_case = re.compile("If ``x_i`` is (.+), the result is (.+)")
538+
r_already_int_case = re.compile(
539+
"If ``x_i`` is already integer-valued, the result is ``x_i``"
540+
)
538541
r_even_round_halves_case = re.compile(
539542
"If two integers are equally close to ``x_i``, "
540543
"the result is the even integer closest to ``x_i``"
541544
)
542545

543546

547+
def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
548+
"""
549+
Returns a strategy that generates float-casted integers within the bounds of dtype.
550+
"""
551+
for k in kw.keys():
552+
# sanity check
553+
assert k in ["min_value", "max_value", "exclude_min", "exclude_max"]
554+
m, M = dh.dtype_ranges[dtype]
555+
if "min_value" in kw.keys():
556+
m = kw["min_value"]
557+
if "exclude_min" in kw.keys():
558+
m += 1
559+
if "max_value" in kw.keys():
560+
M = kw["max_value"]
561+
if "exclude_max" in kw.keys():
562+
M -= 1
563+
return st.integers(math.ceil(m), math.floor(M)).map(float)
564+
565+
544566
def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]:
545567
"""
546568
Returns a strategy that generates floats that end with .5 and are within the
@@ -557,6 +579,13 @@ def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]:
557579
)
558580

559581

582+
already_int_case = UnaryCase(
583+
cond_expr="x_i.is_integer()",
584+
cond=lambda i: i.is_integer(),
585+
cond_from_dtype=integers_from_dtype,
586+
result_expr="x_i",
587+
check_result=lambda i, result: i == result,
588+
)
560589
even_round_halves_case = UnaryCase(
561590
cond_expr="modf(i)[0] == 0.5",
562591
cond=lambda i: math.modf(i)[0] == 0.5,
@@ -624,7 +653,11 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]:
624653
cases = []
625654
for case_m in r_case.finditer(case_block):
626655
case_str = case_m.group(1)
627-
if m := r_unary_case.search(case_str):
656+
if m := r_already_int_case.search(case_str):
657+
cases.append(already_int_case)
658+
elif m := r_even_round_halves_case.search(case_str):
659+
cases.append(even_round_halves_case)
660+
elif m := r_unary_case.search(case_str):
628661
try:
629662
cond, cond_expr_template, cond_from_dtype = parse_cond(m.group(1))
630663
_check_result, result_expr = parse_result(m.group(2))
@@ -643,8 +676,6 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]:
643676
check_result=check_result,
644677
)
645678
cases.append(case)
646-
elif m := r_even_round_halves_case.search(case_str):
647-
cases.append(even_round_halves_case)
648679
else:
649680
if not r_remaining_case.search(case_str):
650681
warn(f"case not machine-readable: '{case_str}'")
@@ -818,25 +849,6 @@ def check_result(i1: float, i2: float, result: float) -> bool:
818849
return check_result
819850

820851

821-
def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
822-
"""
823-
Returns a strategy that generates float-casted integers within the bounds of dtype.
824-
"""
825-
for k in kw.keys():
826-
# sanity check
827-
assert k in ["min_value", "max_value", "exclude_min", "exclude_max"]
828-
m, M = dh.dtype_ranges[dtype]
829-
if "min_value" in kw.keys():
830-
m = kw["min_value"]
831-
if "exclude_min" in kw.keys():
832-
m += 1
833-
if "max_value" in kw.keys():
834-
M = kw["max_value"]
835-
if "exclude_max" in kw.keys():
836-
M -= 1
837-
return st.integers(math.ceil(m), math.floor(M)).map(float)
838-
839-
840852
def parse_binary_case(case_str: str) -> BinaryCase:
841853
"""
842854
Parses a Sphinx-formatted binary case string to return codified binary cases, e.g.

0 commit comments

Comments
 (0)