10
10
11
11
pytest .importorskip ("hypothesis" )
12
12
13
- from hypothesis import given
13
+ import numpy as np
14
+ import torch
15
+ from hypothesis import given , note
14
16
from hypothesis import strategies as st
15
17
from hypothesis .errors import HypothesisWarning
18
+ from hypothesis .extra import numpy as nps
16
19
from hypothesis .extra .array_api import make_strategies_namespace
17
20
18
- import torch_np as np
21
+ import torch_np as tnp
22
+ from torch_np ._dtypes import sctypes
23
+ from torch_np .testing import assert_array_equal
19
24
20
25
__all__ = ["xps" ]
21
26
22
27
with warnings .catch_warnings ():
23
28
warnings .filterwarnings ("ignore" , category = HypothesisWarning )
24
- np .bool = np .bool_
25
- xps = make_strategies_namespace (np , api_version = "2022.12" )
29
+ tnp .bool = tnp .bool_
30
+ xps = make_strategies_namespace (tnp , api_version = "2022.12" )
26
31
27
32
28
- default_dtypes = [np .bool , np .int64 , np .float64 , np .complex128 ]
33
+ default_dtypes = [tnp .bool , tnp .int64 , tnp .float64 , tnp .complex128 ]
29
34
kind_to_strat = {
30
35
"b" : xps .boolean_dtypes (),
31
36
"i" : xps .integer_dtypes (),
32
37
"u" : xps .unsigned_integer_dtypes (sizes = 8 ),
33
38
"f" : xps .floating_dtypes (),
34
39
"c" : xps .complex_dtypes (),
35
40
}
36
- scalar_dtype_strat = st .one_of (kind_to_strat .values ()).map (np .dtype )
41
+ scalar_dtype_strat = st .one_of (kind_to_strat .values ()).map (tnp .dtype )
37
42
38
43
39
44
@pytest .mark .skip (reason = "flaky" )
@@ -55,14 +60,14 @@ def test_full(shape, data):
55
60
else :
56
61
values_dtypes_strat = st .just (_dtype )
57
62
values_strat = values_dtypes_strat .flatmap (
58
- lambda d : values_strat .map (lambda v : np .asarray (v , dtype = d ))
63
+ lambda d : values_strat .map (lambda v : tnp .asarray (v , dtype = d ))
59
64
)
60
65
fill_value = data .draw (values_strat , label = "fill_value" )
61
- out = np .full (shape , fill_value , ** kw )
66
+ out = tnp .full (shape , fill_value , ** kw )
62
67
assert out .dtype == _dtype
63
68
assert out .shape == shape
64
69
if cmath .isnan (fill_value ):
65
- assert np .isnan (out ).all ()
70
+ assert tnp .isnan (out ).all ()
66
71
else :
67
72
assert (out == fill_value ).all ()
68
73
@@ -89,3 +94,48 @@ def test_integer_indexing(x, data):
89
94
idx = data .draw (integer_array_indices (x .shape , result_shape ), label = "idx" )
90
95
result = x [idx ]
91
96
assert result .shape == result_shape
97
+
98
+
99
+ @given (
100
+ np_x = nps .arrays (
101
+ # We specifically use namespaced dtypes to prevent non-native byte-order issues
102
+ dtype = scalar_dtype_strat .map (lambda d : getattr (np , d .name )),
103
+ shape = nps .array_shapes (),
104
+ ),
105
+ data = st .data (),
106
+ )
107
+ def test_put (np_x , data ):
108
+ # We cast arrays from torch_np.asarray as currently it doesn't carry over
109
+ # dtypes. XXX: Remove the below sanity check and subsequent casting when
110
+ # this is fixed.
111
+ assert tnp .asarray (np .zeros (5 , dtype = np .int16 )).dtype != tnp .int16
112
+
113
+ tnp_x = tnp .asarray (np_x .copy ()).astype (np_x .dtype .name )
114
+
115
+ result_shapes = st .shared (nps .array_shapes ())
116
+ ind = data .draw (
117
+ nps .integer_array_indices (np_x .shape , result_shape = result_shapes ), label = "ind"
118
+ )
119
+ v = data .draw (nps .arrays (dtype = np_x .dtype , shape = result_shapes ), label = "v" )
120
+
121
+ tnp_x_copy = tnp_x .copy ()
122
+ np .put (np_x , ind , v )
123
+ note (f"(after put) { np_x = } " )
124
+ assert_array_equal (tnp_x , tnp_x_copy ) # sanity check
125
+
126
+ note (f"{ tnp_x = } " )
127
+ tnp_ind = []
128
+ for np_indices in ind :
129
+ tnp_indices = tnp .asarray (np_indices ).astype (np_indices .dtype .name )
130
+ tnp_ind .append (tnp_indices )
131
+ tnp_ind = tuple (tnp_ind )
132
+ note (f"{ tnp_ind = } " )
133
+ tnp_v = tnp .asarray (v .copy ()).astype (v .dtype .name )
134
+ note (f"{ tnp_v = } " )
135
+ try :
136
+ tnp .put (tnp_x , tnp_ind , tnp_v )
137
+ except NotImplementedError :
138
+ return
139
+ note (f"(after put) { tnp_x = } " )
140
+
141
+ assert_array_equal (tnp_x , tnp .asarray (np_x ).astype (tnp_x .dtype ))
0 commit comments