@@ -30,6 +30,46 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
30
30
return xps .boolean_dtypes () | all_integer_dtypes ()
31
31
32
32
33
+ class OnewayPromotableDtypes (NamedTuple ):
34
+ input_dtype : DataType
35
+ result_dtype : DataType
36
+
37
+
38
+ @st .composite
39
+ def oneway_promotable_dtypes (
40
+ draw , dtypes : List [DataType ]
41
+ ) -> st .SearchStrategy [OnewayPromotableDtypes ]:
42
+ """Return a strategy for input dtypes that promote to result dtypes."""
43
+ d1 , d2 = draw (hh .mutually_promotable_dtypes (dtypes = dtypes ))
44
+ result_dtype = dh .result_type (d1 , d2 )
45
+ if d1 == result_dtype :
46
+ return OnewayPromotableDtypes (d2 , d1 )
47
+ elif d2 == result_dtype :
48
+ return OnewayPromotableDtypes (d1 , d2 )
49
+ else :
50
+ reject ()
51
+
52
+
53
+ class OnewayBroadcastableShapes (NamedTuple ):
54
+ input_shape : Shape
55
+ result_shape : Shape
56
+
57
+
58
+ @st .composite
59
+ def oneway_broadcastable_shapes (draw ) -> st .SearchStrategy [OnewayBroadcastableShapes ]:
60
+ """Return a strategy for input shapes that broadcast to result shapes."""
61
+ result_shape = draw (hh .shapes (min_side = 1 ))
62
+ input_shape = draw (
63
+ xps .broadcastable_shapes (
64
+ result_shape ,
65
+ # Override defaults so bad shapes are less likely to be generated.
66
+ max_side = None if result_shape == () else max (result_shape ),
67
+ max_dims = len (result_shape ),
68
+ ).filter (lambda s : sh .broadcast_shapes (result_shape , s ) == result_shape )
69
+ )
70
+ return OnewayBroadcastableShapes (input_shape , result_shape )
71
+
72
+
33
73
def mock_int_dtype (n : int , dtype : DataType ) -> int :
34
74
"""Returns equivalent of `n` that mocks `dtype` behaviour."""
35
75
nbits = dh .dtype_nbits [dtype ]
@@ -306,8 +346,14 @@ def __repr__(self):
306
346
307
347
308
348
def make_binary_params (
309
- elwise_func_name : str , dtypes_strat : st . SearchStrategy [DataType ]
349
+ elwise_func_name : str , dtypes : List [DataType ]
310
350
) -> List [Param [BinaryParamContext ]]:
351
+ if hh .FILTER_UNDEFINED_DTYPES :
352
+ dtypes = [d for d in dtypes if not isinstance (d , xp ._UndefinedStub )]
353
+ shared_oneway_dtypes = st .shared (oneway_promotable_dtypes (dtypes ))
354
+ left_dtypes = shared_oneway_dtypes .map (lambda D : D .result_dtype )
355
+ right_dtypes = shared_oneway_dtypes .map (lambda D : D .input_dtype )
356
+
311
357
def make_param (
312
358
func_name : str , func_type : FuncType , right_is_scalar : bool
313
359
) -> Param [BinaryParamContext ]:
@@ -318,26 +364,29 @@ def make_param(
318
364
left_sym = "x1"
319
365
right_sym = "x2"
320
366
321
- shared_dtypes = st .shared (dtypes_strat )
322
367
if right_is_scalar :
323
- left_strat = xps .arrays (dtype = shared_dtypes , shape = hh .shapes (** shapes_kw ))
324
- right_strat = shared_dtypes .flatmap (
325
- lambda d : xps .from_dtype (d , ** finite_kw )
326
- )
368
+ left_strat = xps .arrays (dtype = left_dtypes , shape = hh .shapes (** shapes_kw ))
369
+ right_strat = right_dtypes .flatmap (lambda d : xps .from_dtype (d , ** finite_kw ))
327
370
else :
328
371
if func_type is FuncType .IOP :
329
- shared_shapes = st .shared (hh .shapes (** shapes_kw ))
330
- left_strat = xps .arrays (dtype = shared_dtypes , shape = shared_shapes )
331
- right_strat = xps .arrays (dtype = shared_dtypes , shape = shared_shapes )
372
+ shared_oneway_shapes = st .shared (oneway_broadcastable_shapes ())
373
+ left_strat = xps .arrays (
374
+ dtype = left_dtypes ,
375
+ shape = shared_oneway_shapes .map (lambda S : S .result_shape ),
376
+ )
377
+ right_strat = xps .arrays (
378
+ dtype = right_dtypes ,
379
+ shape = shared_oneway_shapes .map (lambda S : S .input_shape ),
380
+ )
332
381
else :
333
382
mutual_shapes = st .shared (
334
383
hh .mutually_broadcastable_shapes (2 , ** shapes_kw )
335
384
)
336
385
left_strat = xps .arrays (
337
- dtype = shared_dtypes , shape = mutual_shapes .map (lambda pair : pair [0 ])
386
+ dtype = left_dtypes , shape = mutual_shapes .map (lambda pair : pair [0 ])
338
387
)
339
388
right_strat = xps .arrays (
340
- dtype = shared_dtypes , shape = mutual_shapes .map (lambda pair : pair [1 ])
389
+ dtype = right_dtypes , shape = mutual_shapes .map (lambda pair : pair [1 ])
341
390
)
342
391
343
392
if func_type is FuncType .FUNC :
@@ -514,7 +563,7 @@ def test_acosh(x):
514
563
)
515
564
516
565
517
- @pytest .mark .parametrize ("ctx," , make_binary_params ("add" , xps .numeric_dtypes () ))
566
+ @pytest .mark .parametrize ("ctx," , make_binary_params ("add" , dh .numeric_dtypes ))
518
567
@given (data = st .data ())
519
568
def test_add (ctx , data ):
520
569
left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -579,7 +628,7 @@ def test_atanh(x):
579
628
580
629
581
630
@pytest .mark .parametrize (
582
- "ctx" , make_binary_params ("bitwise_and" , boolean_and_all_integer_dtypes () )
631
+ "ctx" , make_binary_params ("bitwise_and" , dh . bool_and_all_int_dtypes )
583
632
)
584
633
@given (data = st .data ())
585
634
def test_bitwise_and (ctx , data ):
@@ -598,7 +647,7 @@ def test_bitwise_and(ctx, data):
598
647
599
648
600
649
@pytest .mark .parametrize (
601
- "ctx" , make_binary_params ("bitwise_left_shift" , all_integer_dtypes () )
650
+ "ctx" , make_binary_params ("bitwise_left_shift" , dh . all_int_dtypes )
602
651
)
603
652
@given (data = st .data ())
604
653
def test_bitwise_left_shift (ctx , data ):
@@ -638,7 +687,7 @@ def test_bitwise_invert(ctx, data):
638
687
639
688
640
689
@pytest .mark .parametrize (
641
- "ctx" , make_binary_params ("bitwise_or" , boolean_and_all_integer_dtypes () )
690
+ "ctx" , make_binary_params ("bitwise_or" , dh . bool_and_all_int_dtypes )
642
691
)
643
692
@given (data = st .data ())
644
693
def test_bitwise_or (ctx , data ):
@@ -657,7 +706,7 @@ def test_bitwise_or(ctx, data):
657
706
658
707
659
708
@pytest .mark .parametrize (
660
- "ctx" , make_binary_params ("bitwise_right_shift" , all_integer_dtypes () )
709
+ "ctx" , make_binary_params ("bitwise_right_shift" , dh . all_int_dtypes )
661
710
)
662
711
@given (data = st .data ())
663
712
def test_bitwise_right_shift (ctx , data ):
@@ -678,7 +727,7 @@ def test_bitwise_right_shift(ctx, data):
678
727
679
728
680
729
@pytest .mark .parametrize (
681
- "ctx" , make_binary_params ("bitwise_xor" , boolean_and_all_integer_dtypes () )
730
+ "ctx" , make_binary_params ("bitwise_xor" , dh . bool_and_all_int_dtypes )
682
731
)
683
732
@given (data = st .data ())
684
733
def test_bitwise_xor (ctx , data ):
@@ -720,7 +769,7 @@ def test_cosh(x):
720
769
unary_assert_against_refimpl ("cosh" , x , out , math .cosh )
721
770
722
771
723
- @pytest .mark .parametrize ("ctx" , make_binary_params ("divide" , xps . floating_dtypes () ))
772
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("divide" , dh . float_dtypes ))
724
773
@given (data = st .data ())
725
774
def test_divide (ctx , data ):
726
775
left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -743,7 +792,7 @@ def test_divide(ctx, data):
743
792
)
744
793
745
794
746
- @pytest .mark .parametrize ("ctx" , make_binary_params ("equal" , xps . scalar_dtypes () ))
795
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("equal" , dh . all_dtypes ))
747
796
@given (data = st .data ())
748
797
def test_equal (ctx , data ):
749
798
left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -795,9 +844,7 @@ def test_floor(x):
795
844
unary_assert_against_refimpl ("floor" , x , out , math .floor , strict_check = True )
796
845
797
846
798
- @pytest .mark .parametrize (
799
- "ctx" , make_binary_params ("floor_divide" , xps .numeric_dtypes ())
800
- )
847
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("floor_divide" , dh .numeric_dtypes ))
801
848
@given (data = st .data ())
802
849
def test_floor_divide (ctx , data ):
803
850
left = data .draw (
@@ -816,7 +863,7 @@ def test_floor_divide(ctx, data):
816
863
binary_param_assert_against_refimpl (ctx , left , right , res , "//" , operator .floordiv )
817
864
818
865
819
- @pytest .mark .parametrize ("ctx" , make_binary_params ("greater" , xps .numeric_dtypes () ))
866
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("greater" , dh .numeric_dtypes ))
820
867
@given (data = st .data ())
821
868
def test_greater (ctx , data ):
822
869
left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -836,9 +883,7 @@ def test_greater(ctx, data):
836
883
)
837
884
838
885
839
- @pytest .mark .parametrize (
840
- "ctx" , make_binary_params ("greater_equal" , xps .numeric_dtypes ())
841
- )
886
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("greater_equal" , dh .numeric_dtypes ))
842
887
@given (data = st .data ())
843
888
def test_greater_equal (ctx , data ):
844
889
left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -882,7 +927,7 @@ def test_isnan(x):
882
927
unary_assert_against_refimpl ("isnan" , x , out , math .isnan , res_stype = bool )
883
928
884
929
885
- @pytest .mark .parametrize ("ctx" , make_binary_params ("less" , xps .numeric_dtypes () ))
930
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("less" , dh .numeric_dtypes ))
886
931
@given (data = st .data ())
887
932
def test_less (ctx , data ):
888
933
left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -902,7 +947,7 @@ def test_less(ctx, data):
902
947
)
903
948
904
949
905
- @pytest .mark .parametrize ("ctx" , make_binary_params ("less_equal" , xps .numeric_dtypes () ))
950
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("less_equal" , dh .numeric_dtypes ))
906
951
@given (data = st .data ())
907
952
def test_less_equal (ctx , data ):
908
953
left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -1014,7 +1059,7 @@ def test_logical_xor(x1, x2):
1014
1059
)
1015
1060
1016
1061
1017
- @pytest .mark .parametrize ("ctx" , make_binary_params ("multiply" , xps .numeric_dtypes () ))
1062
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("multiply" , dh .numeric_dtypes ))
1018
1063
@given (data = st .data ())
1019
1064
def test_multiply (ctx , data ):
1020
1065
left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -1047,7 +1092,7 @@ def test_negative(ctx, data):
1047
1092
)
1048
1093
1049
1094
1050
- @pytest .mark .parametrize ("ctx" , make_binary_params ("not_equal" , xps . scalar_dtypes () ))
1095
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("not_equal" , dh . all_dtypes ))
1051
1096
@given (data = st .data ())
1052
1097
def test_not_equal (ctx , data ):
1053
1098
left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -1079,7 +1124,7 @@ def test_positive(ctx, data):
1079
1124
ph .assert_array (ctx .func_name , out , x )
1080
1125
1081
1126
1082
- @pytest .mark .parametrize ("ctx" , make_binary_params ("pow" , xps .numeric_dtypes () ))
1127
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("pow" , dh .numeric_dtypes ))
1083
1128
@given (data = st .data ())
1084
1129
def test_pow (ctx , data ):
1085
1130
left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -1103,7 +1148,7 @@ def test_pow(ctx, data):
1103
1148
)
1104
1149
1105
1150
1106
- @pytest .mark .parametrize ("ctx" , make_binary_params ("remainder" , xps .numeric_dtypes () ))
1151
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("remainder" , dh .numeric_dtypes ))
1107
1152
@given (data = st .data ())
1108
1153
def test_remainder (ctx , data ):
1109
1154
left = data .draw (ctx .left_strat , label = ctx .left_sym )
@@ -1174,7 +1219,7 @@ def test_sqrt(x):
1174
1219
)
1175
1220
1176
1221
1177
- @pytest .mark .parametrize ("ctx" , make_binary_params ("subtract" , xps .numeric_dtypes () ))
1222
+ @pytest .mark .parametrize ("ctx" , make_binary_params ("subtract" , dh .numeric_dtypes ))
1178
1223
@given (data = st .data ())
1179
1224
def test_subtract (ctx , data ):
1180
1225
left = data .draw (ctx .left_strat , label = ctx .left_sym )
0 commit comments