@@ -25,15 +25,14 @@ def nonzero(a: ArrayLike):
25
25
26
26
27
27
@normalizer
28
- def argwhere (a : ArrayLike ):
28
+ def argwhere (a : ArrayLike ) -> NDArray :
29
29
result = torch .argwhere (a )
30
- return _helpers .array_from (result )
31
-
30
+ return result
32
31
33
32
@normalizer
34
- def flatnonzero (a : ArrayLike ):
33
+ def flatnonzero (a : ArrayLike ) -> NDArray :
35
34
result = a .ravel ().nonzero (as_tuple = True )[0 ]
36
- return _helpers . array_from ( result )
35
+ return result
37
36
38
37
39
38
@normalizer
@@ -50,25 +49,24 @@ def clip(
50
49
51
50
52
51
@normalizer
53
- def repeat (a : ArrayLike , repeats : ArrayLike , axis = None ):
54
- # XXX: scalar repeats; ArrayLikeOrScalar ?
52
+ def repeat (a : ArrayLike , repeats : ArrayLike , axis = None ) -> NDArray :
55
53
result = torch .repeat_interleave (a , repeats , axis )
56
- return _helpers . array_from ( result )
54
+ return result
57
55
58
56
59
57
@normalizer
60
- def tile (A : ArrayLike , reps ):
58
+ def tile (A : ArrayLike , reps ) -> NDArray :
61
59
result = _impl .tile (A , reps )
62
- return _helpers . array_from ( result )
60
+ return result
63
61
64
62
65
63
# ### diag et al ###
66
64
67
65
68
66
@normalizer
69
- def diagonal (a : ArrayLike , offset = 0 , axis1 = 0 , axis2 = 1 ):
67
+ def diagonal (a : ArrayLike , offset = 0 , axis1 = 0 , axis2 = 1 ) -> NDArray :
70
68
result = _impl .diagonal (a , offset , axis1 , axis2 )
71
- return _helpers . array_from ( result )
69
+ return result
72
70
73
71
74
72
@normalizer
@@ -85,29 +83,29 @@ def trace(
85
83
86
84
87
85
@normalizer
88
- def eye (N , M = None , k = 0 , dtype : DTypeLike = float , order = "C" , * , like : SubokLike = None ):
86
+ def eye (N , M = None , k = 0 , dtype : DTypeLike = float , order = "C" , * , like : SubokLike = None ) -> NDArray :
89
87
if order != "C" :
90
88
raise NotImplementedError
91
89
result = _impl .eye (N , M , k , dtype )
92
- return _helpers . array_from ( result )
90
+ return result
93
91
94
92
95
93
@normalizer
96
- def identity (n , dtype : DTypeLike = None , * , like : SubokLike = None ):
94
+ def identity (n , dtype : DTypeLike = None , * , like : SubokLike = None ) -> NDArray :
97
95
result = torch .eye (n , dtype = dtype )
98
- return _helpers . array_from ( result )
96
+ return result
99
97
100
98
101
99
@normalizer
102
- def diag (v : ArrayLike , k = 0 ):
100
+ def diag (v : ArrayLike , k = 0 ) -> NDArray :
103
101
result = torch .diag (v , k )
104
- return _helpers . array_from ( result )
102
+ return result
105
103
106
104
107
105
@normalizer
108
- def diagflat (v : ArrayLike , k = 0 ):
106
+ def diagflat (v : ArrayLike , k = 0 ) -> NDArray :
109
107
result = torch .diagflat (v , k )
110
- return _helpers . array_from ( result )
108
+ return result
111
109
112
110
113
111
def diag_indices (n , ndim = 2 ):
@@ -122,9 +120,9 @@ def diag_indices_from(arr: ArrayLike):
122
120
123
121
124
122
@normalizer
125
- def fill_diagonal (a : ArrayLike , val : ArrayLike , wrap = False ):
123
+ def fill_diagonal (a : ArrayLike , val : ArrayLike , wrap = False ) -> NDArray :
126
124
result = _impl .fill_diagonal (a , val , wrap )
127
- return _helpers . array_from ( result )
125
+ return result
128
126
129
127
130
128
@normalizer
@@ -143,93 +141,93 @@ def dot(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None):
143
141
144
142
145
143
@normalizer
146
- def sort (a : ArrayLike , axis = - 1 , kind = None , order = None ):
144
+ def sort (a : ArrayLike , axis = - 1 , kind = None , order = None ) -> NDArray :
147
145
result = _impl .sort (a , axis , kind , order )
148
- return _helpers . array_from ( result )
146
+ return result
149
147
150
148
151
149
@normalizer
152
- def argsort (a : ArrayLike , axis = - 1 , kind = None , order = None ):
150
+ def argsort (a : ArrayLike , axis = - 1 , kind = None , order = None ) -> NDArray :
153
151
result = _impl .argsort (a , axis , kind , order )
154
- return _helpers . array_from ( result )
152
+ return result
155
153
156
154
157
155
@normalizer
158
156
def searchsorted (
159
157
a : ArrayLike , v : ArrayLike , side = "left" , sorter : Optional [ArrayLike ] = None
160
- ):
158
+ ) -> NDArray :
161
159
result = torch .searchsorted (a , v , side = side , sorter = sorter )
162
- return _helpers . array_from ( result )
160
+ return result
163
161
164
162
165
163
# ### swap/move/roll axis ###
166
164
167
165
168
166
@normalizer
169
- def moveaxis (a : ArrayLike , source , destination ):
167
+ def moveaxis (a : ArrayLike , source , destination ) -> NDArray :
170
168
result = _impl .moveaxis (a , source , destination )
171
- return _helpers . array_from ( result )
169
+ return result
172
170
173
171
174
172
@normalizer
175
- def swapaxes (a : ArrayLike , axis1 , axis2 ):
173
+ def swapaxes (a : ArrayLike , axis1 , axis2 ) -> NDArray :
176
174
result = _impl .swapaxes (a , axis1 , axis2 )
177
- return _helpers . array_from ( result )
175
+ return result
178
176
179
177
180
178
@normalizer
181
- def rollaxis (a : ArrayLike , axis , start = 0 ):
179
+ def rollaxis (a : ArrayLike , axis , start = 0 ) -> NDArray :
182
180
result = _impl .rollaxis (a , axis , start )
183
- return _helpers . array_from ( result )
181
+ return result
184
182
185
183
186
184
# ### shape manipulations ###
187
185
188
186
189
187
@normalizer
190
- def squeeze (a : ArrayLike , axis = None ):
188
+ def squeeze (a : ArrayLike , axis = None ) -> NDArray :
191
189
result = _impl .squeeze (a , axis )
192
- return _helpers . array_from ( result , a )
190
+ return result
193
191
194
192
195
193
@normalizer
196
- def reshape (a : ArrayLike , newshape , order = "C" ):
194
+ def reshape (a : ArrayLike , newshape , order = "C" ) -> NDArray :
197
195
result = _impl .reshape (a , newshape , order = order )
198
- return _helpers . array_from ( result , a )
196
+ return result
199
197
200
198
201
199
@normalizer
202
- def transpose (a : ArrayLike , axes = None ):
200
+ def transpose (a : ArrayLike , axes = None ) -> NDArray :
203
201
result = _impl .transpose (a , axes )
204
- return _helpers . array_from ( result , a )
202
+ return result
205
203
206
204
207
205
@normalizer
208
- def ravel (a : ArrayLike , order = "C" ):
206
+ def ravel (a : ArrayLike , order = "C" ) -> NDArray :
209
207
result = _impl .ravel (a )
210
- return _helpers . array_from ( result , a )
208
+ return result
211
209
212
210
213
211
# leading underscore since arr.flatten exists but np.flatten does not
214
212
@normalizer
215
- def _flatten (a : ArrayLike , order = "C" ):
213
+ def _flatten (a : ArrayLike , order = "C" ) -> NDArray :
216
214
result = _impl ._flatten (a )
217
- return _helpers . array_from ( result , a )
215
+ return result
218
216
219
217
220
218
# ### Type/shape etc queries ###
221
219
222
220
223
221
@normalizer
224
- def real (a : ArrayLike ):
222
+ def real (a : ArrayLike ) -> NDArray :
225
223
result = torch .real (a )
226
- return _helpers . array_from ( result )
224
+ return result
227
225
228
226
229
227
@normalizer
230
- def imag (a : ArrayLike ):
228
+ def imag (a : ArrayLike ) -> NDArray :
231
229
result = _impl .imag (a )
232
- return _helpers . array_from ( result )
230
+ return result
233
231
234
232
235
233
@normalizer
@@ -419,9 +417,9 @@ def any(
419
417
420
418
421
419
@normalizer
422
- def count_nonzero (a : ArrayLike , axis : AxisLike = None , * , keepdims = False ):
420
+ def count_nonzero (a : ArrayLike , axis : AxisLike = None , * , keepdims = False ) -> NDArray :
423
421
result = _impl .count_nonzero (a , axis = axis , keepdims = keepdims )
424
- return _helpers . array_from ( result )
422
+ return result
425
423
426
424
427
425
@normalizer
0 commit comments