Skip to content

Commit 56d452d

Browse files
committed
Factor out logic for input eq conds/results
1 parent 6f05c53 commit 56d452d

File tree

1 file changed

+91
-79
lines changed

1 file changed

+91
-79
lines changed

array_api_tests/test_special_cases.py

+91-79
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from dataclasses import dataclass
55
from decimal import ROUND_HALF_EVEN, Decimal
66
from enum import Enum, auto
7-
from typing import Callable, List, Match, Protocol, Tuple
7+
from typing import Callable, List, Match, Optional, Protocol, Tuple
88
from warnings import warn
99

1010
import pytest
@@ -372,33 +372,46 @@ class BinaryCase(Case):
372372
r_both_inputs_are_value = re.compile("are both (.+)")
373373

374374

375-
class BinaryCondInput(Enum):
375+
class BinaryCondArg(Enum):
376376
FIRST = auto()
377377
SECOND = auto()
378378
BOTH = auto()
379379
EITHER = auto()
380380

381+
@classmethod
382+
def from_x_no(cls, string):
383+
if string == "1":
384+
return cls.FIRST
385+
elif string == "2":
386+
return cls.SECOND
387+
else:
388+
raise ValueError(f"{string=} not '1' or '2'")
381389

382-
def noop(obj):
383-
return obj
384390

391+
def noop(n: float) -> float:
392+
return n
385393

386-
def make_partial_cond(
387-
input_: BinaryCondInput, unary_check: UnaryCheck, *, input_wrapper=None
394+
395+
def make_binary_cond(
396+
cond_arg: BinaryCondArg,
397+
unary_check: UnaryCheck,
398+
*,
399+
input_wrapper: Optional[Callable[[float], float]] = None,
388400
) -> BinaryCond:
389401
if input_wrapper is None:
390402
input_wrapper = noop
391-
if input_ == BinaryCondInput.FIRST:
403+
404+
if cond_arg == BinaryCondArg.FIRST:
392405

393406
def partial_cond(i1: float, i2: float) -> bool:
394407
return unary_check(input_wrapper(i1))
395408

396-
elif input_ == BinaryCondInput.SECOND:
409+
elif cond_arg == BinaryCondArg.SECOND:
397410

398411
def partial_cond(i1: float, i2: float) -> bool:
399412
return unary_check(input_wrapper(i2))
400413

401-
elif input_ == BinaryCondInput.BOTH:
414+
elif cond_arg == BinaryCondArg.BOTH:
402415

403416
def partial_cond(i1: float, i2: float) -> bool:
404417
return unary_check(input_wrapper(i1)) and unary_check(input_wrapper(i2))
@@ -411,50 +424,78 @@ def partial_cond(i1: float, i2: float) -> bool:
411424
return partial_cond
412425

413426

414-
def parse_binary_case(case_m: Match) -> BinaryCase:
415-
cond_strs = r_cond_sep.split(case_m.group(1))
416-
partial_conds = []
417-
partial_exprs = []
418-
for cond_str in cond_strs:
419-
if m := r_input_is_array_element.match(cond_str):
420-
in_sign, input_array, value_sign, value_array = m.groups()
421-
assert in_sign == "" and value_array != input_array # sanity check
422-
partial_expr = f"{in_sign}x{input_array}ᵢ == {value_sign}x{value_array}ᵢ"
423-
if value_array == "1":
424-
if value_sign != "-":
427+
def make_eq_other_input_cond(
428+
eq_to: BinaryCondArg, *, eq_neg: bool = False
429+
) -> BinaryCond:
430+
if eq_neg:
431+
input_wrapper = lambda i: -i
432+
else:
433+
input_wrapper = noop
425434

426-
def partial_cond(i1: float, i2: float) -> bool:
427-
eq = make_eq(i1)
428-
return eq(i2)
435+
if eq_to == BinaryCondArg.FIRST:
429436

430-
else:
437+
def cond(i1: float, i2: float) -> bool:
438+
eq = make_eq(input_wrapper(i1))
439+
return eq(i2)
431440

432-
def partial_cond(i1: float, i2: float) -> bool:
433-
eq = make_eq(-i1)
434-
return eq(i2)
441+
elif eq_to == BinaryCondArg.SECOND:
435442

436-
else:
437-
if value_sign != "-":
443+
def cond(i1: float, i2: float) -> bool:
444+
eq = make_eq(input_wrapper(i2))
445+
return eq(i1)
446+
447+
else:
448+
raise ValueError(f"{eq_to=} must be FIRST or SECOND")
438449

439-
def partial_cond(i1: float, i2: float) -> bool:
440-
eq = make_eq(i2)
441-
return eq(i1)
450+
return cond
442451

443-
else:
444452

445-
def partial_cond(i1: float, i2: float) -> bool:
446-
eq = make_eq(-i2)
447-
return eq(i1)
453+
def make_eq_input_check_result(
454+
eq_to: BinaryCondArg, *, eq_neg: bool = False
455+
) -> BinaryResultCheck:
456+
if eq_neg:
457+
input_wrapper = lambda i: -i
458+
else:
459+
input_wrapper = noop
460+
461+
if eq_to == BinaryCondArg.FIRST:
462+
463+
def check_result(i1: float, i2: float, result: float) -> bool:
464+
eq = make_eq(input_wrapper(i1))
465+
return eq(result)
466+
467+
elif eq_to == BinaryCondArg.SECOND:
468+
469+
def check_result(i1: float, i2: float, result: float) -> bool:
470+
eq = make_eq(input_wrapper(i2))
471+
return eq(result)
472+
473+
else:
474+
raise ValueError(f"{eq_to=} must be FIRST or SECOND")
475+
476+
return check_result
448477

478+
479+
def parse_binary_case(case_m: Match) -> BinaryCase:
480+
cond_strs = r_cond_sep.split(case_m.group(1))
481+
partial_conds = []
482+
partial_exprs = []
483+
for cond_str in cond_strs:
484+
if m := r_input_is_array_element.match(cond_str):
485+
in_sign, in_no, other_sign, other_no = m.groups()
486+
assert in_sign == "" and other_no != in_no # sanity check
487+
partial_expr = f"{in_sign}x{in_no}ᵢ == {other_sign}x{other_no}ᵢ"
488+
partial_cond = make_eq_other_input_cond( # type: ignore
489+
BinaryCondArg.from_x_no(other_no), eq_neg=other_sign == "-"
490+
)
449491
elif m := r_both_inputs_are_value.match(cond_str):
450492
unary_cond, expr_template = parse_cond(m.group(1))
451493
left_expr = expr_template.replace("{}", "x1ᵢ")
452494
right_expr = expr_template.replace("{}", "x2ᵢ")
453495
partial_expr = f"({left_expr}) and ({right_expr})"
454-
partial_cond = make_partial_cond( # type: ignore
455-
BinaryCondInput.BOTH, unary_cond
496+
partial_cond = make_binary_cond( # type: ignore
497+
BinaryCondArg.BOTH, unary_cond
456498
)
457-
458499
else:
459500
cond_m = r_cond.match(cond_str)
460501
if cond_m is None:
@@ -484,32 +525,26 @@ def partial_cond(i1: float, i2: float) -> bool:
484525
if m := r_input.match(input_str):
485526
x_no = m.group(1)
486527
partial_expr = expr_template.replace("{}", f"x{x_no}ᵢ")
487-
if x_no == "1":
488-
input_ = BinaryCondInput.FIRST
489-
else:
490-
input_ = BinaryCondInput.SECOND
528+
cond_arg = BinaryCondArg.from_x_no(x_no)
491529
elif m := r_abs_input.match(input_str):
492530
x_no = m.group(1)
493531
partial_expr = expr_template.replace("{}", f"abs(x{x_no}ᵢ)")
494-
if x_no == "1":
495-
input_ = BinaryCondInput.FIRST
496-
else:
497-
input_ = BinaryCondInput.SECOND
532+
cond_arg = BinaryCondArg.from_x_no(x_no)
498533
input_wrapper = abs
499534
elif r_and_input.match(input_str):
500535
left_expr = expr_template.replace("{}", "x1ᵢ")
501536
right_expr = expr_template.replace("{}", "x2ᵢ")
502537
partial_expr = f"({left_expr}) and ({right_expr})"
503-
input_ = BinaryCondInput.BOTH
538+
cond_arg = BinaryCondArg.BOTH
504539
elif r_or_input.match(input_str):
505540
left_expr = expr_template.replace("{}", "x1ᵢ")
506541
right_expr = expr_template.replace("{}", "x2ᵢ")
507542
partial_expr = f"({left_expr}) or ({right_expr})"
508-
input_ = BinaryCondInput.EITHER
543+
cond_arg = BinaryCondArg.EITHER
509544
else:
510545
raise ValueParseError(input_str)
511-
partial_cond = make_partial_cond( # type: ignore
512-
input_, unary_check, input_wrapper=input_wrapper
546+
partial_cond = make_binary_cond( # type: ignore
547+
cond_arg, unary_check, input_wrapper=input_wrapper
513548
)
514549

515550
partial_conds.append(partial_cond)
@@ -520,34 +555,11 @@ def partial_cond(i1: float, i2: float) -> bool:
520555
raise ValueParseError(case_m.group(2))
521556
result_str = result_m.group(1)
522557
if m := r_array_element.match(result_str):
523-
sign, input_ = m.groups()
524-
result_expr = f"{sign}x{input_}ᵢ"
525-
if input_ == "1":
526-
if sign != "-":
527-
528-
def check_result(i1: float, i2: float, result: float) -> bool:
529-
eq = make_eq(i1)
530-
return eq(result)
531-
532-
else:
533-
534-
def check_result(i1: float, i2: float, result: float) -> bool:
535-
eq = make_eq(-i1)
536-
return eq(result)
537-
538-
else:
539-
if sign != "-":
540-
541-
def check_result(i1: float, i2: float, result: float) -> bool:
542-
eq = make_eq(i2)
543-
return eq(result)
544-
545-
else:
546-
547-
def check_result(i1: float, i2: float, result: float) -> bool:
548-
eq = make_eq(-i2)
549-
return eq(result)
550-
558+
sign, x_no = m.groups()
559+
result_expr = f"{sign}x{x_no}ᵢ"
560+
check_result = make_eq_input_check_result( # type: ignore
561+
BinaryCondArg.from_x_no(x_no), eq_neg=sign == "-"
562+
)
551563
else:
552564
_check_result, result_expr = parse_result(result_m.group(1))
553565

0 commit comments

Comments
 (0)