@@ -30,6 +30,26 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
30
30
return xps .boolean_dtypes () | all_integer_dtypes ()
31
31
32
32
33
+ class OnewayBroadcastableShapes (NamedTuple ):
34
+ input_shape : Shape
35
+ result_shape : Shape
36
+
37
+
38
+ @st .composite
39
+ def oneway_broadcastable_shapes (draw ) -> st .SearchStrategy [OnewayBroadcastableShapes ]:
40
+ """Return a strategy for input shapes that broadcast to result shapes."""
41
+ result_shape = draw (hh .shapes (min_side = 1 ))
42
+ input_shape = draw (
43
+ xps .broadcastable_shapes (
44
+ result_shape ,
45
+ # Override defaults so bad shapes are less likely to be generated.
46
+ max_side = None if result_shape == () else max (result_shape ),
47
+ max_dims = len (result_shape ),
48
+ ).filter (lambda s : sh .broadcast_shapes (result_shape , s ) == result_shape )
49
+ )
50
+ return OnewayBroadcastableShapes (input_shape , result_shape )
51
+
52
+
33
53
def mock_int_dtype (n : int , dtype : DataType ) -> int :
34
54
"""Returns equivalent of `n` that mocks `dtype` behaviour."""
35
55
nbits = dh .dtype_nbits [dtype ]
@@ -326,9 +346,15 @@ def make_param(
326
346
)
327
347
else :
328
348
if func_type is FuncType .IOP :
329
- shared_shapes = st .shared (hh .shapes (** shapes_kw ))
330
- left_strat = xps .arrays (dtype = shared_dtypes , shape = shared_shapes )
331
- right_strat = xps .arrays (dtype = shared_dtypes , shape = shared_shapes )
349
+ shared_oneway_shapes = st .shared (oneway_broadcastable_shapes ())
350
+ left_strat = xps .arrays (
351
+ dtype = shared_dtypes ,
352
+ shape = shared_oneway_shapes .map (lambda S : S .result_shape ),
353
+ )
354
+ right_strat = xps .arrays (
355
+ dtype = shared_dtypes ,
356
+ shape = shared_oneway_shapes .map (lambda S : S .input_shape ),
357
+ )
332
358
else :
333
359
mutual_shapes = st .shared (
334
360
hh .mutually_broadcastable_shapes (2 , ** shapes_kw )
0 commit comments