9
9
from hypothesis .errors import HypothesisWarning
10
10
from hypothesis .extra .array_api import make_strategies_namespace
11
11
12
- import torch_np
12
+ import torch_np as np
13
13
14
14
__all__ = ["xps" ]
15
15
16
16
with warnings .catch_warnings ():
17
17
warnings .filterwarnings ("ignore" , category = HypothesisWarning )
18
- xps = make_strategies_namespace (torch_np , api_version = "2021.12" )
18
+ np .bool = np .bool_
19
+ xps = make_strategies_namespace (np , api_version = "2021.12" )
19
20
20
21
21
22
def integer_array_indices (shape , result_shape ) -> st .SearchStrategy [tuple ]:
@@ -40,3 +41,29 @@ def test_integer_indexing(x, data):
40
41
idx = data .draw (integer_array_indices (x .shape , result_shape ), label = "idx" )
41
42
result = x [idx ]
42
43
assert result .shape == result_shape
44
+
45
+
46
+ @given (shape = xps .array_shapes (), data = st .data ())
47
+ def test_full (shape , data ):
48
+ if data .draw (st .booleans (), label = "pass kwargs?" ):
49
+ kw = {}
50
+ else :
51
+ dtype = data .draw (st .none () | xps .scalar_dtypes (), label = "dtype" )
52
+ kw = {"dtype" : dtype }
53
+ _dtype = kw .get ("dtype" , None ) or data .draw (
54
+ st .sampled_from ([np .bool , np .int64 , np .float64 ]), label = "_dtype"
55
+ )
56
+ fill_value = data .draw (xps .from_dtype (_dtype ), label = "fill_value" )
57
+ out = np .full (shape , fill_value , ** kw )
58
+ if kw .get ("dtype" , None ) is None :
59
+ if isinstance (fill_value , bool ):
60
+ assert out .dtype == np .bool
61
+ elif isinstance (fill_value , int ):
62
+ assert out .dtype == np .int64
63
+ else :
64
+ assert isinstance (fill_value , float ) # sanity check
65
+ assert out .dtype == np .float64
66
+ else :
67
+ assert out .dtype == kw ["dtype" ]
68
+ assert out .shape == shape
69
+ assert (out == fill_value ).all ()
0 commit comments