@@ -30,6 +30,26 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
30
30
return xps .boolean_dtypes () | all_integer_dtypes ()
31
31
32
32
33
+ class OnewayPromotableDtypes (NamedTuple ):
34
+ input_dtype : DataType
35
+ result_dtype : DataType
36
+
37
+
38
+ @st .composite
39
+ def oneway_promotable_dtypes (
40
+ draw , dtypes : List [DataType ]
41
+ ) -> st .SearchStrategy [OnewayPromotableDtypes ]:
42
+ """Return a strategy for input dtypes that promote to result dtypes."""
43
+ d1 , d2 = draw (hh .mutually_promotable_dtypes (dtypes = dtypes ))
44
+ result_dtype = dh .result_type (d1 , d2 )
45
+ if d1 == result_dtype :
46
+ return OnewayPromotableDtypes (d2 , d1 )
47
+ elif d2 == result_dtype :
48
+ return OnewayPromotableDtypes (d1 , d2 )
49
+ else :
50
+ reject ()
51
+
52
+
33
53
class OnewayBroadcastableShapes (NamedTuple ):
34
54
input_shape : Shape
35
55
result_shape : Shape
@@ -326,8 +346,14 @@ def __repr__(self):
326
346
327
347
328
348
def make_binary_params (
329
- elwise_func_name : str , dtypes_strat : st . SearchStrategy [DataType ]
349
+ elwise_func_name : str , dtypes : List [DataType ]
330
350
) -> List [Param [BinaryParamContext ]]:
351
+ if hh .FILTER_UNDEFINED_DTYPES :
352
+ dtypes = [d for d in dtypes if not isinstance (d , xp ._UndefinedStub )]
353
+ shared_oneway_dtypes = st .shared (oneway_promotable_dtypes (dtypes ))
354
+ left_dtypes = shared_oneway_dtypes .map (lambda D : D .result_dtype )
355
+ right_dtypes = shared_oneway_dtypes .map (lambda D : D .input_dtype )
356
+
331
357
def make_param (
332
358
func_name : str , func_type : FuncType , right_is_scalar : bool
333
359
) -> Param [BinaryParamContext ]:
@@ -338,32 +364,29 @@ def make_param(
338
364
left_sym = "x1"
339
365
right_sym = "x2"
340
366
341
- shared_dtypes = st .shared (dtypes_strat )
342
367
if right_is_scalar :
343
- left_strat = xps .arrays (dtype = shared_dtypes , shape = hh .shapes (** shapes_kw ))
344
- right_strat = shared_dtypes .flatmap (
345
- lambda d : xps .from_dtype (d , ** finite_kw )
346
- )
368
+ left_strat = xps .arrays (dtype = left_dtypes , shape = hh .shapes (** shapes_kw ))
369
+ right_strat = right_dtypes .flatmap (lambda d : xps .from_dtype (d , ** finite_kw ))
347
370
else :
348
371
if func_type is FuncType .IOP :
349
372
shared_oneway_shapes = st .shared (oneway_broadcastable_shapes ())
350
373
left_strat = xps .arrays (
351
- dtype = shared_dtypes ,
374
+ dtype = left_dtypes ,
352
375
shape = shared_oneway_shapes .map (lambda S : S .result_shape ),
353
376
)
354
377
right_strat = xps .arrays (
355
- dtype = shared_dtypes ,
378
+ dtype = right_dtypes ,
356
379
shape = shared_oneway_shapes .map (lambda S : S .input_shape ),
357
380
)
358
381
else :
359
382
mutual_shapes = st .shared (
360
383
hh .mutually_broadcastable_shapes (2 , ** shapes_kw )
361
384
)
362
385
left_strat = xps .arrays (
363
- dtype = shared_dtypes , shape = mutual_shapes .map (lambda pair : pair [0 ])
386
+ dtype = left_dtypes , shape = mutual_shapes .map (lambda pair : pair [0 ])
364
387
)
365
388
right_strat = xps .arrays (
366
- dtype = shared_dtypes , shape = mutual_shapes .map (lambda pair : pair [1 ])
389
+ dtype = right_dtypes , shape = mutual_shapes .map (lambda pair : pair [1 ])
367
390
)
368
391
369
392
if func_type is FuncType .FUNC :
@@ -540,7 +563,7 @@ def test_acosh(x):
540
563
)
541
564
542
565
543
- @pytest .mark .parametrize ("ctx," , make_binary_params ("add" , xps .numeric_dtypes () ))
566
+ @pytest .mark .parametrize ("ctx," , make_binary_params ("add" , dh .numeric_dtypes ))
544
567
@given (data = st .data ())
545
568
def test_add (ctx , data ):
546
569
left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -605,7 +628,7 @@ def test_atanh(x):
605
628
606
629
607
630
@pytest .mark .parametrize (
608
- "ctx" , make_binary_params ("bitwise_and" , boolean_and_all_integer_dtypes () )
631
+ "ctx" , make_binary_params ("bitwise_and" , dh . bool_and_all_int_dtypes )
609
632
)
610
633
@given (data = st .data ())
611
634
def test_bitwise_and (ctx , data ):
@@ -624,7 +647,7 @@ def test_bitwise_and(ctx, data):
624
647
625
648
626
649
@pytest .mark .parametrize (
627
- "ctx" , make_binary_params ("bitwise_left_shift" , all_integer_dtypes () )
650
+ "ctx" , make_binary_params ("bitwise_left_shift" , dh . all_int_dtypes )
628
651
)
629
652
@given (data = st .data ())
630
653
def test_bitwise_left_shift (ctx , data ):
@@ -664,7 +687,7 @@ def test_bitwise_invert(ctx, data):
664
687
665
688
666
689
@pytest .mark .parametrize (
667
- "ctx" , make_binary_params ("bitwise_or" , boolean_and_all_integer_dtypes () )
690
+ "ctx" , make_binary_params ("bitwise_or" , dh . bool_and_all_int_dtypes )
668
691
)
669
692
@given (data = st .data ())
670
693
def test_bitwise_or (ctx , data ):
@@ -683,7 +706,7 @@ def test_bitwise_or(ctx, data):
683
706
684
707
685
708
@pytest .mark .parametrize (
686
- "ctx" , make_binary_params ("bitwise_right_shift" , all_integer_dtypes () )
709
+ "ctx" , make_binary_params ("bitwise_right_shift" , dh . all_int_dtypes )
687
710
)
688
711
@given (data = st .data ())
689
712
def test_bitwise_right_shift (ctx , data ):
@@ -704,7 +727,7 @@ def test_bitwise_right_shift(ctx, data):
704
727
705
728
706
729
@pytest .mark .parametrize (
707
- "ctx" , make_binary_params ("bitwise_xor" , boolean_and_all_integer_dtypes () )
730
+ "ctx" , make_binary_params ("bitwise_xor" , dh . bool_and_all_int_dtypes )
708
731
)
709
732
@given (data = st .data ())
710
733
def test_bitwise_xor (ctx , data ):
@@ -746,7 +769,7 @@ def test_cosh(x):
746
769
unary_assert_against_refimpl ("cosh" , x , out , math .cosh )
747
770
748
771
749
- @pytest .mark .parametrize ("ctx" , make_binary_params ("divide" , xps . floating_dtypes () ))
772
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("divide" , dh . float_dtypes ))
750
773
@given (data = st .data ())
751
774
def test_divide (ctx , data ):
752
775
left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -769,7 +792,7 @@ def test_divide(ctx, data):
769
792
)
770
793
771
794
772
- @pytest .mark .parametrize ("ctx" , make_binary_params ("equal" , xps . scalar_dtypes () ))
795
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("equal" , dh . all_dtypes ))
773
796
@given (data = st .data ())
774
797
def test_equal (ctx , data ):
775
798
left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -821,9 +844,7 @@ def test_floor(x):
821
844
unary_assert_against_refimpl ("floor" , x , out , math .floor , strict_check = True )
822
845
823
846
824
- @pytest .mark .parametrize (
825
- "ctx" , make_binary_params ("floor_divide" , xps .numeric_dtypes ())
826
- )
847
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("floor_divide" , dh .numeric_dtypes ))
827
848
@given (data = st .data ())
828
849
def test_floor_divide (ctx , data ):
829
850
left = data .draw (
@@ -842,7 +863,7 @@ def test_floor_divide(ctx, data):
842
863
binary_param_assert_against_refimpl (ctx , left , right , res , "//" , operator .floordiv )
843
864
844
865
845
- @pytest .mark .parametrize ("ctx" , make_binary_params ("greater" , xps .numeric_dtypes () ))
866
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("greater" , dh .numeric_dtypes ))
846
867
@given (data = st .data ())
847
868
def test_greater (ctx , data ):
848
869
left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -862,9 +883,7 @@ def test_greater(ctx, data):
862
883
)
863
884
864
885
865
- @pytest .mark .parametrize (
866
- "ctx" , make_binary_params ("greater_equal" , xps .numeric_dtypes ())
867
- )
886
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("greater_equal" , dh .numeric_dtypes ))
868
887
@given (data = st .data ())
869
888
def test_greater_equal (ctx , data ):
870
889
left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -908,7 +927,7 @@ def test_isnan(x):
908
927
unary_assert_against_refimpl ("isnan" , x , out , math .isnan , res_stype = bool )
909
928
910
929
911
- @pytest .mark .parametrize ("ctx" , make_binary_params ("less" , xps .numeric_dtypes () ))
930
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("less" , dh .numeric_dtypes ))
912
931
@given (data = st .data ())
913
932
def test_less (ctx , data ):
914
933
left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -928,7 +947,7 @@ def test_less(ctx, data):
928
947
)
929
948
930
949
931
- @pytest .mark .parametrize ("ctx" , make_binary_params ("less_equal" , xps .numeric_dtypes () ))
950
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("less_equal" , dh .numeric_dtypes ))
932
951
@given (data = st .data ())
933
952
def test_less_equal (ctx , data ):
934
953
left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -1040,7 +1059,7 @@ def test_logical_xor(x1, x2):
1040
1059
)
1041
1060
1042
1061
1043
- @pytest .mark .parametrize ("ctx" , make_binary_params ("multiply" , xps .numeric_dtypes () ))
1062
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("multiply" , dh .numeric_dtypes ))
1044
1063
@given (data = st .data ())
1045
1064
def test_multiply (ctx , data ):
1046
1065
left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -1073,7 +1092,7 @@ def test_negative(ctx, data):
1073
1092
)
1074
1093
1075
1094
1076
- @pytest .mark .parametrize ("ctx" , make_binary_params ("not_equal" , xps . scalar_dtypes () ))
1095
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("not_equal" , dh . all_dtypes ))
1077
1096
@given (data = st .data ())
1078
1097
def test_not_equal (ctx , data ):
1079
1098
left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -1105,7 +1124,7 @@ def test_positive(ctx, data):
1105
1124
ph .assert_array (ctx .func_name , out , x )
1106
1125
1107
1126
1108
- @pytest .mark .parametrize ("ctx" , make_binary_params ("pow" , xps .numeric_dtypes () ))
1127
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("pow" , dh .numeric_dtypes ))
1109
1128
@given (data = st .data ())
1110
1129
def test_pow (ctx , data ):
1111
1130
left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -1129,7 +1148,7 @@ def test_pow(ctx, data):
1129
1148
)
1130
1149
1131
1150
1132
- @pytest .mark .parametrize ("ctx" , make_binary_params ("remainder" , xps .numeric_dtypes () ))
1151
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("remainder" , dh .numeric_dtypes ))
1133
1152
@given (data = st .data ())
1134
1153
def test_remainder (ctx , data ):
1135
1154
left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -1200,7 +1219,7 @@ def test_sqrt(x):
1200
1219
)
1201
1220
1202
1221
1203
- @pytest .mark .parametrize ("ctx" , make_binary_params ("subtract" , xps .numeric_dtypes () ))
1222
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("subtract" , dh .numeric_dtypes ))
1204
1223
@given (data = st .data ())
1205
1224
def test_subtract (ctx , data ):
1206
1225
left = data .draw (ctx .left_strat , label = ctx .left_sym )
0 commit comments