@@ -535,12 +535,34 @@ class UnaryCase(Case):
535
535
536
536
537
537
r_unary_case = re .compile ("If ``x_i`` is (.+), the result is (.+)" )
538
+ r_already_int_case = re .compile (
539
+ "If ``x_i`` is already integer-valued, the result is ``x_i``"
540
+ )
538
541
r_even_round_halves_case = re .compile (
539
542
"If two integers are equally close to ``x_i``, "
540
543
"the result is the even integer closest to ``x_i``"
541
544
)
542
545
543
546
547
+ def integers_from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
548
+ """
549
+ Returns a strategy that generates float-casted integers within the bounds of dtype.
550
+ """
551
+ for k in kw .keys ():
552
+ # sanity check
553
+ assert k in ["min_value" , "max_value" , "exclude_min" , "exclude_max" ]
554
+ m , M = dh .dtype_ranges [dtype ]
555
+ if "min_value" in kw .keys ():
556
+ m = kw ["min_value" ]
557
+ if "exclude_min" in kw .keys ():
558
+ m += 1
559
+ if "max_value" in kw .keys ():
560
+ M = kw ["max_value" ]
561
+ if "exclude_max" in kw .keys ():
562
+ M -= 1
563
+ return st .integers (math .ceil (m ), math .floor (M )).map (float )
564
+
565
+
544
566
def trailing_halves_from_dtype (dtype : DataType ) -> st .SearchStrategy [float ]:
545
567
"""
546
568
Returns a strategy that generates floats that end with .5 and are within the
@@ -557,6 +579,13 @@ def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]:
557
579
)
558
580
559
581
582
+ already_int_case = UnaryCase (
583
+ cond_expr = "x_i.is_integer()" ,
584
+ cond = lambda i : i .is_integer (),
585
+ cond_from_dtype = integers_from_dtype ,
586
+ result_expr = "x_i" ,
587
+ check_result = lambda i , result : i == result ,
588
+ )
560
589
even_round_halves_case = UnaryCase (
561
590
cond_expr = "modf(i)[0] == 0.5" ,
562
591
cond = lambda i : math .modf (i )[0 ] == 0.5 ,
@@ -624,7 +653,11 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]:
624
653
cases = []
625
654
for case_m in r_case .finditer (case_block ):
626
655
case_str = case_m .group (1 )
627
- if m := r_unary_case .search (case_str ):
656
+ if m := r_already_int_case .search (case_str ):
657
+ cases .append (already_int_case )
658
+ elif m := r_even_round_halves_case .search (case_str ):
659
+ cases .append (even_round_halves_case )
660
+ elif m := r_unary_case .search (case_str ):
628
661
try :
629
662
cond , cond_expr_template , cond_from_dtype = parse_cond (m .group (1 ))
630
663
_check_result , result_expr = parse_result (m .group (2 ))
@@ -643,8 +676,6 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]:
643
676
check_result = check_result ,
644
677
)
645
678
cases .append (case )
646
- elif m := r_even_round_halves_case .search (case_str ):
647
- cases .append (even_round_halves_case )
648
679
else :
649
680
if not r_remaining_case .search (case_str ):
650
681
warn (f"case not machine-readable: '{ case_str } '" )
@@ -818,25 +849,6 @@ def check_result(i1: float, i2: float, result: float) -> bool:
818
849
return check_result
819
850
820
851
821
- def integers_from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
822
- """
823
- Returns a strategy that generates float-casted integers within the bounds of dtype.
824
- """
825
- for k in kw .keys ():
826
- # sanity check
827
- assert k in ["min_value" , "max_value" , "exclude_min" , "exclude_max" ]
828
- m , M = dh .dtype_ranges [dtype ]
829
- if "min_value" in kw .keys ():
830
- m = kw ["min_value" ]
831
- if "exclude_min" in kw .keys ():
832
- m += 1
833
- if "max_value" in kw .keys ():
834
- M = kw ["max_value" ]
835
- if "exclude_max" in kw .keys ():
836
- M -= 1
837
- return st .integers (math .ceil (m ), math .floor (M )).map (float )
838
-
839
-
840
852
def parse_binary_case (case_str : str ) -> BinaryCase :
841
853
"""
842
854
Parses a Sphinx-formatted binary case string to return codified binary cases, e.g.
0 commit comments