Skip to content

Commit c8d2339

Browse files
committed
Generate x1 is x2 (and visa verse) conds healthily
1 parent 74a73a9 commit c8d2339

File tree

1 file changed

+32
-30
lines changed

1 file changed

+32
-30
lines changed

array_api_tests/test_special_cases.py

+32-30
Original file line numberDiff line numberDiff line change
@@ -567,32 +567,6 @@ def partial_cond(i1: float, i2: float) -> bool:
567567
return partial_cond
568568

569569

570-
def make_eq_other_input_cond(
571-
eq_to: BinaryCondArg, *, eq_neg: bool = False
572-
) -> BinaryCond:
573-
if eq_neg:
574-
input_wrapper = lambda i: -i
575-
else:
576-
input_wrapper = noop
577-
578-
if eq_to == BinaryCondArg.FIRST:
579-
580-
def cond(i1: float, i2: float) -> bool:
581-
eq = make_strict_eq(input_wrapper(i1))
582-
return eq(i2)
583-
584-
elif eq_to == BinaryCondArg.SECOND:
585-
586-
def cond(i1: float, i2: float) -> bool:
587-
eq = make_strict_eq(input_wrapper(i2))
588-
return eq(i1)
589-
590-
else:
591-
raise ValueError(f"{eq_to=} must be FIRST or SECOND")
592-
593-
return cond
594-
595-
596570
def make_eq_input_check_result(
597571
eq_to: BinaryCondArg, *, eq_neg: bool = False
598572
) -> BinaryResultCheck:
@@ -616,8 +590,6 @@ def check_result(i1: float, i2: float, result: float) -> bool:
616590
else:
617591
raise ValueError(f"{eq_to=} must be FIRST or SECOND")
618592

619-
return check_result
620-
621593

622594
def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
623595
for k in kw.keys():
@@ -649,9 +621,39 @@ def parse_binary_case(case_str: str) -> BinaryCase:
649621
in_sign, in_no, other_sign, other_no = m.groups()
650622
assert in_sign == "" and other_no != in_no # sanity check
651623
partial_expr = f"{in_sign}x{in_no}_i == {other_sign}x{other_no}_i"
652-
partial_cond = make_eq_other_input_cond( # type: ignore
653-
BinaryCondArg.from_x_no(other_no), eq_neg=other_sign == "-"
624+
input_wrapper = lambda i: -i if other_sign == "-" else noop
625+
shared_from_dtype = lambda d, **kw: st.shared(
626+
xps.from_dtype(d, **kw), key=cond_str
654627
)
628+
629+
if other_no == "1":
630+
631+
def partial_cond(i1: float, i2: float) -> bool:
632+
eq = make_strict_eq(input_wrapper(i1))
633+
return eq(i2)
634+
635+
_x2_cond_from_dtype = shared_from_dtype # type: ignore
636+
637+
def _x1_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]:
638+
return shared_from_dtype(dtype, **kw).map(input_wrapper)
639+
640+
elif other_no == "2":
641+
642+
def partial_cond(i1: float, i2: float) -> bool:
643+
eq = make_strict_eq(input_wrapper(i2))
644+
return eq(i1)
645+
646+
_x1_cond_from_dtype = shared_from_dtype # type: ignore
647+
648+
def _x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]:
649+
return shared_from_dtype(dtype, **kw).map(input_wrapper)
650+
651+
else:
652+
raise ValueParseError(cond_str)
653+
654+
x1_cond_from_dtypes.append(BoundFromDtype(base_func=_x1_cond_from_dtype))
655+
x2_cond_from_dtypes.append(BoundFromDtype(base_func=_x2_cond_from_dtype))
656+
655657
elif m := r_both_inputs_are_value.match(cond_str):
656658
unary_cond, expr_template, cond_from_dtype = parse_cond(m.group(1))
657659
left_expr = expr_template.replace("{}", "x1_i")

0 commit comments

Comments
 (0)