4
4
from dataclasses import dataclass
5
5
from decimal import ROUND_HALF_EVEN , Decimal
6
6
from enum import Enum , auto
7
- from typing import Callable , List , Match , Protocol , Tuple
7
+ from typing import Callable , List , Match , Optional , Protocol , Tuple
8
8
from warnings import warn
9
9
10
10
import pytest
@@ -372,33 +372,46 @@ class BinaryCase(Case):
372
372
r_both_inputs_are_value = re .compile ("are both (.+)" )
373
373
374
374
375
- class BinaryCondInput (Enum ):
375
+ class BinaryCondArg (Enum ):
376
376
FIRST = auto ()
377
377
SECOND = auto ()
378
378
BOTH = auto ()
379
379
EITHER = auto ()
380
380
381
+ @classmethod
382
+ def from_x_no (cls , string ):
383
+ if string == "1" :
384
+ return cls .FIRST
385
+ elif string == "2" :
386
+ return cls .SECOND
387
+ else :
388
+ raise ValueError (f"{ string = } not '1' or '2'" )
381
389
382
- def noop (obj ):
383
- return obj
384
390
391
+ def noop (n : float ) -> float :
392
+ return n
385
393
386
- def make_partial_cond (
387
- input_ : BinaryCondInput , unary_check : UnaryCheck , * , input_wrapper = None
394
+
395
+ def make_binary_cond (
396
+ cond_arg : BinaryCondArg ,
397
+ unary_check : UnaryCheck ,
398
+ * ,
399
+ input_wrapper : Optional [Callable [[float ], float ]] = None ,
388
400
) -> BinaryCond :
389
401
if input_wrapper is None :
390
402
input_wrapper = noop
391
- if input_ == BinaryCondInput .FIRST :
403
+
404
+ if cond_arg == BinaryCondArg .FIRST :
392
405
393
406
def partial_cond (i1 : float , i2 : float ) -> bool :
394
407
return unary_check (input_wrapper (i1 ))
395
408
396
- elif input_ == BinaryCondInput .SECOND :
409
+ elif cond_arg == BinaryCondArg .SECOND :
397
410
398
411
def partial_cond (i1 : float , i2 : float ) -> bool :
399
412
return unary_check (input_wrapper (i2 ))
400
413
401
- elif input_ == BinaryCondInput .BOTH :
414
+ elif cond_arg == BinaryCondArg .BOTH :
402
415
403
416
def partial_cond (i1 : float , i2 : float ) -> bool :
404
417
return unary_check (input_wrapper (i1 )) and unary_check (input_wrapper (i2 ))
@@ -411,50 +424,78 @@ def partial_cond(i1: float, i2: float) -> bool:
411
424
return partial_cond
412
425
413
426
414
- def parse_binary_case (case_m : Match ) -> BinaryCase :
415
- cond_strs = r_cond_sep .split (case_m .group (1 ))
416
- partial_conds = []
417
- partial_exprs = []
418
- for cond_str in cond_strs :
419
- if m := r_input_is_array_element .match (cond_str ):
420
- in_sign , input_array , value_sign , value_array = m .groups ()
421
- assert in_sign == "" and value_array != input_array # sanity check
422
- partial_expr = f"{ in_sign } x{ input_array } ᵢ == { value_sign } x{ value_array } ᵢ"
423
- if value_array == "1" :
424
- if value_sign != "-" :
427
+ def make_eq_other_input_cond (
428
+ eq_to : BinaryCondArg , * , eq_neg : bool = False
429
+ ) -> BinaryCond :
430
+ if eq_neg :
431
+ input_wrapper = lambda i : - i
432
+ else :
433
+ input_wrapper = noop
425
434
426
- def partial_cond (i1 : float , i2 : float ) -> bool :
427
- eq = make_eq (i1 )
428
- return eq (i2 )
435
+ if eq_to == BinaryCondArg .FIRST :
429
436
430
- else :
437
+ def cond (i1 : float , i2 : float ) -> bool :
438
+ eq = make_eq (input_wrapper (i1 ))
439
+ return eq (i2 )
431
440
432
- def partial_cond (i1 : float , i2 : float ) -> bool :
433
- eq = make_eq (- i1 )
434
- return eq (i2 )
441
+ elif eq_to == BinaryCondArg .SECOND :
435
442
436
- else :
437
- if value_sign != "-" :
443
+ def cond (i1 : float , i2 : float ) -> bool :
444
+ eq = make_eq (input_wrapper (i2 ))
445
+ return eq (i1 )
446
+
447
+ else :
448
+ raise ValueError (f"{ eq_to = } must be FIRST or SECOND" )
438
449
439
- def partial_cond (i1 : float , i2 : float ) -> bool :
440
- eq = make_eq (i2 )
441
- return eq (i1 )
450
+ return cond
442
451
443
- else :
444
452
445
- def partial_cond (i1 : float , i2 : float ) -> bool :
446
- eq = make_eq (- i2 )
447
- return eq (i1 )
453
+ def make_eq_input_check_result (
454
+ eq_to : BinaryCondArg , * , eq_neg : bool = False
455
+ ) -> BinaryResultCheck :
456
+ if eq_neg :
457
+ input_wrapper = lambda i : - i
458
+ else :
459
+ input_wrapper = noop
460
+
461
+ if eq_to == BinaryCondArg .FIRST :
462
+
463
+ def check_result (i1 : float , i2 : float , result : float ) -> bool :
464
+ eq = make_eq (input_wrapper (i1 ))
465
+ return eq (result )
466
+
467
+ elif eq_to == BinaryCondArg .SECOND :
468
+
469
+ def check_result (i1 : float , i2 : float , result : float ) -> bool :
470
+ eq = make_eq (input_wrapper (i2 ))
471
+ return eq (result )
472
+
473
+ else :
474
+ raise ValueError (f"{ eq_to = } must be FIRST or SECOND" )
475
+
476
+ return check_result
448
477
478
+
479
+ def parse_binary_case (case_m : Match ) -> BinaryCase :
480
+ cond_strs = r_cond_sep .split (case_m .group (1 ))
481
+ partial_conds = []
482
+ partial_exprs = []
483
+ for cond_str in cond_strs :
484
+ if m := r_input_is_array_element .match (cond_str ):
485
+ in_sign , in_no , other_sign , other_no = m .groups ()
486
+ assert in_sign == "" and other_no != in_no # sanity check
487
+ partial_expr = f"{ in_sign } x{ in_no } ᵢ == { other_sign } x{ other_no } ᵢ"
488
+ partial_cond = make_eq_other_input_cond ( # type: ignore
489
+ BinaryCondArg .from_x_no (other_no ), eq_neg = other_sign == "-"
490
+ )
449
491
elif m := r_both_inputs_are_value .match (cond_str ):
450
492
unary_cond , expr_template = parse_cond (m .group (1 ))
451
493
left_expr = expr_template .replace ("{}" , "x1ᵢ" )
452
494
right_expr = expr_template .replace ("{}" , "x2ᵢ" )
453
495
partial_expr = f"({ left_expr } ) and ({ right_expr } )"
454
- partial_cond = make_partial_cond ( # type: ignore
455
- BinaryCondInput .BOTH , unary_cond
496
+ partial_cond = make_binary_cond ( # type: ignore
497
+ BinaryCondArg .BOTH , unary_cond
456
498
)
457
-
458
499
else :
459
500
cond_m = r_cond .match (cond_str )
460
501
if cond_m is None :
@@ -484,32 +525,26 @@ def partial_cond(i1: float, i2: float) -> bool:
484
525
if m := r_input .match (input_str ):
485
526
x_no = m .group (1 )
486
527
partial_expr = expr_template .replace ("{}" , f"x{ x_no } ᵢ" )
487
- if x_no == "1" :
488
- input_ = BinaryCondInput .FIRST
489
- else :
490
- input_ = BinaryCondInput .SECOND
528
+ cond_arg = BinaryCondArg .from_x_no (x_no )
491
529
elif m := r_abs_input .match (input_str ):
492
530
x_no = m .group (1 )
493
531
partial_expr = expr_template .replace ("{}" , f"abs(x{ x_no } ᵢ)" )
494
- if x_no == "1" :
495
- input_ = BinaryCondInput .FIRST
496
- else :
497
- input_ = BinaryCondInput .SECOND
532
+ cond_arg = BinaryCondArg .from_x_no (x_no )
498
533
input_wrapper = abs
499
534
elif r_and_input .match (input_str ):
500
535
left_expr = expr_template .replace ("{}" , "x1ᵢ" )
501
536
right_expr = expr_template .replace ("{}" , "x2ᵢ" )
502
537
partial_expr = f"({ left_expr } ) and ({ right_expr } )"
503
- input_ = BinaryCondInput .BOTH
538
+ cond_arg = BinaryCondArg .BOTH
504
539
elif r_or_input .match (input_str ):
505
540
left_expr = expr_template .replace ("{}" , "x1ᵢ" )
506
541
right_expr = expr_template .replace ("{}" , "x2ᵢ" )
507
542
partial_expr = f"({ left_expr } ) or ({ right_expr } )"
508
- input_ = BinaryCondInput .EITHER
543
+ cond_arg = BinaryCondArg .EITHER
509
544
else :
510
545
raise ValueParseError (input_str )
511
- partial_cond = make_partial_cond ( # type: ignore
512
- input_ , unary_check , input_wrapper = input_wrapper
546
+ partial_cond = make_binary_cond ( # type: ignore
547
+ cond_arg , unary_check , input_wrapper = input_wrapper
513
548
)
514
549
515
550
partial_conds .append (partial_cond )
@@ -520,34 +555,11 @@ def partial_cond(i1: float, i2: float) -> bool:
520
555
raise ValueParseError (case_m .group (2 ))
521
556
result_str = result_m .group (1 )
522
557
if m := r_array_element .match (result_str ):
523
- sign , input_ = m .groups ()
524
- result_expr = f"{ sign } x{ input_ } ᵢ"
525
- if input_ == "1" :
526
- if sign != "-" :
527
-
528
- def check_result (i1 : float , i2 : float , result : float ) -> bool :
529
- eq = make_eq (i1 )
530
- return eq (result )
531
-
532
- else :
533
-
534
- def check_result (i1 : float , i2 : float , result : float ) -> bool :
535
- eq = make_eq (- i1 )
536
- return eq (result )
537
-
538
- else :
539
- if sign != "-" :
540
-
541
- def check_result (i1 : float , i2 : float , result : float ) -> bool :
542
- eq = make_eq (i2 )
543
- return eq (result )
544
-
545
- else :
546
-
547
- def check_result (i1 : float , i2 : float , result : float ) -> bool :
548
- eq = make_eq (- i2 )
549
- return eq (result )
550
-
558
+ sign , x_no = m .groups ()
559
+ result_expr = f"{ sign } x{ x_no } ᵢ"
560
+ check_result = make_eq_input_check_result ( # type: ignore
561
+ BinaryCondArg .from_x_no (x_no ), eq_neg = sign == "-"
562
+ )
551
563
else :
552
564
_check_result , result_expr = parse_result (result_m .group (1 ))
553
565
0 commit comments