@@ -567,32 +567,6 @@ def partial_cond(i1: float, i2: float) -> bool:
567
567
return partial_cond
568
568
569
569
570
- def make_eq_other_input_cond (
571
- eq_to : BinaryCondArg , * , eq_neg : bool = False
572
- ) -> BinaryCond :
573
- if eq_neg :
574
- input_wrapper = lambda i : - i
575
- else :
576
- input_wrapper = noop
577
-
578
- if eq_to == BinaryCondArg .FIRST :
579
-
580
- def cond (i1 : float , i2 : float ) -> bool :
581
- eq = make_strict_eq (input_wrapper (i1 ))
582
- return eq (i2 )
583
-
584
- elif eq_to == BinaryCondArg .SECOND :
585
-
586
- def cond (i1 : float , i2 : float ) -> bool :
587
- eq = make_strict_eq (input_wrapper (i2 ))
588
- return eq (i1 )
589
-
590
- else :
591
- raise ValueError (f"{ eq_to = } must be FIRST or SECOND" )
592
-
593
- return cond
594
-
595
-
596
570
def make_eq_input_check_result (
597
571
eq_to : BinaryCondArg , * , eq_neg : bool = False
598
572
) -> BinaryResultCheck :
@@ -616,8 +590,6 @@ def check_result(i1: float, i2: float, result: float) -> bool:
616
590
else :
617
591
raise ValueError (f"{ eq_to = } must be FIRST or SECOND" )
618
592
619
- return check_result
620
-
621
593
622
594
def integers_from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
623
595
for k in kw .keys ():
@@ -649,9 +621,39 @@ def parse_binary_case(case_str: str) -> BinaryCase:
649
621
in_sign , in_no , other_sign , other_no = m .groups ()
650
622
assert in_sign == "" and other_no != in_no # sanity check
651
623
partial_expr = f"{ in_sign } x{ in_no } _i == { other_sign } x{ other_no } _i"
652
- partial_cond = make_eq_other_input_cond ( # type: ignore
653
- BinaryCondArg .from_x_no (other_no ), eq_neg = other_sign == "-"
624
+ input_wrapper = lambda i : - i if other_sign == "-" else noop
625
+ shared_from_dtype = lambda d , ** kw : st .shared (
626
+ xps .from_dtype (d , ** kw ), key = cond_str
654
627
)
628
+
629
+ if other_no == "1" :
630
+
631
+ def partial_cond (i1 : float , i2 : float ) -> bool :
632
+ eq = make_strict_eq (input_wrapper (i1 ))
633
+ return eq (i2 )
634
+
635
+ _x2_cond_from_dtype = shared_from_dtype # type: ignore
636
+
637
+ def _x1_cond_from_dtype (dtype , ** kw ) -> st .SearchStrategy [float ]:
638
+ return shared_from_dtype (dtype , ** kw ).map (input_wrapper )
639
+
640
+ elif other_no == "2" :
641
+
642
+ def partial_cond (i1 : float , i2 : float ) -> bool :
643
+ eq = make_strict_eq (input_wrapper (i2 ))
644
+ return eq (i1 )
645
+
646
+ _x1_cond_from_dtype = shared_from_dtype # type: ignore
647
+
648
+ def _x2_cond_from_dtype (dtype , ** kw ) -> st .SearchStrategy [float ]:
649
+ return shared_from_dtype (dtype , ** kw ).map (input_wrapper )
650
+
651
+ else :
652
+ raise ValueParseError (cond_str )
653
+
654
+ x1_cond_from_dtypes .append (BoundFromDtype (base_func = _x1_cond_from_dtype ))
655
+ x2_cond_from_dtypes .append (BoundFromDtype (base_func = _x2_cond_from_dtype ))
656
+
655
657
elif m := r_both_inputs_are_value .match (cond_str ):
656
658
unary_cond , expr_template , cond_from_dtype = parse_cond (m .group (1 ))
657
659
left_expr = expr_template .replace ("{}" , "x1_i" )
0 commit comments