7
7
from hypothesis .strategies import (lists , integers , sampled_from ,
8
8
shared , floats , just , composite , one_of ,
9
9
none , booleans )
10
- from hypothesis .extra .array_api import make_strategies_namespace
11
10
12
11
from .pytest_helpers import nargs
13
12
from .array_helpers import (dtype_ranges , integer_dtype_objects ,
17
16
ndindex )
18
17
from ._array_module import (full , float32 , float64 , bool as bool_dtype ,
19
18
_UndefinedStub , eye , broadcast_to )
20
- from . import _array_module
21
19
from . import _array_module as xp
20
+ from . import xps
22
21
23
22
from .function_stubs import elementwise_functions
24
23
25
24
26
- xps = make_strategies_namespace (xp )
27
-
28
-
29
25
# Set this to True to not fail tests just because a dtype isn't implemented.
30
26
# If no compatible dtype is implemented for a given test, the test will fail
31
27
# with a hypothesis health check error. Note that this functionality will not
@@ -79,10 +75,6 @@ def mutually_promotable_dtypes(draw, dtype_objects=dtype_objects):
79
75
dtype_pairs = [(i , j ) for i , j in dtype_pairs if i in dtype_objects and j in dtype_objects ]
80
76
return draw (sampled_from (dtype_pairs ))
81
77
82
- shared_mutually_promotable_dtype_pairs = shared (
83
- mutually_promotable_dtypes (), key = "mutually_promotable_dtype_pair"
84
- )
85
-
86
78
# shared() allows us to draw either the function or the function name and they
87
79
# will both correspond to the same function.
88
80
@@ -93,10 +85,10 @@ def mutually_promotable_dtypes(draw, dtype_objects=dtype_objects):
93
85
lambda func_name : nargs (func_name ) > 1 )
94
86
95
87
elementwise_function_objects = elementwise_functions_names .map (
96
- lambda i : getattr (_array_module , i ))
88
+ lambda i : getattr (xp , i ))
97
89
array_functions = elementwise_function_objects
98
90
multiarg_array_functions = multiarg_array_functions_names .map (
99
- lambda i : getattr (_array_module , i ))
91
+ lambda i : getattr (xp , i ))
100
92
101
93
# Limit the total size of an array shape
102
94
MAX_ARRAY_SIZE = 10000
@@ -184,7 +176,6 @@ def two_broadcastable_shapes(draw):
184
176
sizes = integers (0 , MAX_ARRAY_SIZE )
185
177
sqrt_sizes = integers (0 , SQRT_MAX_ARRAY_SIZE )
186
178
187
- # TODO: Generate general arrays here, rather than just scalars.
188
179
numeric_arrays = xps .arrays (
189
180
dtype = shared (xps .floating_dtypes (), key = 'dtypes' ),
190
181
shape = shared (xps .array_shapes (), key = 'shapes' ),
@@ -295,14 +286,18 @@ def multiaxis_indices(draw, shapes):
295
286
return tuple (res )
296
287
297
288
298
- shared_arrays1 = xps .arrays (
299
- dtype = shared_mutually_promotable_dtype_pairs .map (lambda pair : pair [0 ]),
300
- shape = shared (two_mutually_broadcastable_shapes , key = "shape_pair" ).map (lambda pair : pair [0 ]),
301
- )
302
- shared_arrays2 = xps .arrays (
303
- dtype = shared_mutually_promotable_dtype_pairs .map (lambda pair : pair [1 ]),
304
- shape = shared (two_mutually_broadcastable_shapes , key = "shape_pair" ).map (lambda pair : pair [1 ]),
305
- )
289
+ def two_mutual_arrays (dtypes = dtype_objects ):
290
+ mutual_dtypes = shared (mutually_promotable_dtypes (dtypes ))
291
+ mutual_shapes = shared (two_mutually_broadcastable_shapes )
292
+ arrays1 = xps .arrays (
293
+ dtype = mutual_dtypes .map (lambda pair : pair [0 ]),
294
+ shape = mutual_shapes .map (lambda pair : pair [0 ]),
295
+ )
296
+ arrays2 = xps .arrays (
297
+ dtype = mutual_dtypes .map (lambda pair : pair [1 ]),
298
+ shape = mutual_shapes .map (lambda pair : pair [1 ]),
299
+ )
300
+ return arrays1 , arrays2
306
301
307
302
308
303
@composite
0 commit comments