Skip to content

Commit e15c14c

Browse files
committed
MAINT: use normalizer/ArrayLike in _funcs
1 parent 6744787 commit e15c14c

File tree

1 file changed

+68
-69
lines changed

1 file changed

+68
-69
lines changed

torch_np/_funcs.py

Lines changed: 68 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -137,31 +137,29 @@ def clip(a : ArrayLike, min : Optional[ArrayLike]=None, max : Optional[ArrayLike
137137
return _helpers.result_or_out(result, out)
138138

139139

140-
def repeat(a, repeats, axis=None):
141-
tensor, t_repeats = _helpers.to_tensors(a, repeats) # XXX: scalar repeats
142-
result = torch.repeat_interleave(tensor, t_repeats, axis)
140+
@normalizer
141+
def repeat(a : ArrayLike, repeats: ArrayLike, axis=None):
142+
# XXX: scalar repeats; ArrayLikeOrScalar ?
143+
result = torch.repeat_interleave(a, repeats, axis)
143144
return _helpers.array_from(result)
144145

145146

146147
# ### diag et al ###
147148

148-
149-
def diagonal(a, offset=0, axis1=0, axis2=1):
150-
(tensor,) = _helpers.to_tensors(a)
151-
result = _impl.diagonal(tensor, offset, axis1, axis2)
149+
@normalizer
150+
def diagonal(a : ArrayLike, offset=0, axis1=0, axis2=1):
151+
result = _impl.diagonal(a, offset, axis1, axis2)
152152
return _helpers.array_from(result)
153153

154154

155155
@normalizer
156156
def trace(a: ArrayLike, offset=0, axis1=0, axis2=1, dtype: DTypeLike = None, out=None):
157-
# (tensor,) = _helpers.to_tensors(a)
158157
result = _impl.trace(a, offset, axis1, axis2, dtype)
159158
return _helpers.result_or_out(result, out)
160159

161160

162161
@normalizer
163162
def eye(N, M=None, k=0, dtype: DTypeLike = float, order="C", *, like: SubokLike = None):
164-
# _util.subok_not_ok(like)
165163
if order != "C":
166164
raise NotImplementedError
167165
result = _impl.eye(N, M, k, dtype)
@@ -170,20 +168,19 @@ def eye(N, M=None, k=0, dtype: DTypeLike = float, order="C", *, like: SubokLike
170168

171169
@normalizer
172170
def identity(n, dtype: DTypeLike = None, *, like: SubokLike = None):
173-
## _util.subok_not_ok(like)
174171
result = torch.eye(n, dtype=dtype)
175172
return _helpers.array_from(result)
176173

177174

178-
def diag(v, k=0):
179-
(tensor,) = _helpers.to_tensors(v)
180-
result = torch.diag(tensor, k)
175+
@normalizer
176+
def diag(v : ArrayLike, k=0):
177+
result = torch.diag(v, k)
181178
return _helpers.array_from(result)
182179

183180

184-
def diagflat(v, k=0):
185-
(tensor,) = _helpers.to_tensors(v)
186-
result = torch.diagflat(tensor, k)
181+
@normalizer
182+
def diagflat(v : ArrayLike, k=0):
183+
result = torch.diagflat(v, k)
187184
return _helpers.array_from(result)
188185

189186

@@ -192,124 +189,126 @@ def diag_indices(n, ndim=2):
192189
return _helpers.tuple_arrays_from(result)
193190

194191

195-
def diag_indices_from(arr):
196-
(tensor,) = _helpers.to_tensors(arr)
197-
result = _impl.diag_indices_from(tensor)
192+
@normalizer
193+
def diag_indices_from(arr : ArrayLike):
194+
result = _impl.diag_indices_from(arr)
198195
return _helpers.tuple_arrays_from(result)
199196

200197

201-
def fill_diagonal(a, val, wrap=False):
202-
tensor, t_val = _helpers.to_tensors(a, val)
203-
result = _impl.fill_diagonal(tensor, t_val, wrap)
198+
@normalizer
199+
def fill_diagonal(a : ArrayLike, val : ArrayLike, wrap=False):
200+
result = _impl.fill_diagonal(a, val, wrap)
204201
return _helpers.array_from(result)
205202

206203

207-
def vdot(a, b, /):
208-
t_a, t_b = _helpers.to_tensors(a, b)
209-
result = _impl.vdot(t_a, t_b)
204+
@normalizer
205+
def vdot(a : ArrayLike, b : ArrayLike, /):
206+
# t_a, t_b = _helpers.to_tensors(a, b)
207+
result = _impl.vdot(a, b)
210208
return result.item()
211209

212210

213-
def dot(a, b, out=None):
214-
t_a, t_b = _helpers.to_tensors(a, b)
215-
result = _impl.dot(t_a, t_b)
211+
@normalizer
212+
def dot(a : ArrayLike, b : ArrayLike, out=None):
213+
# t_a, t_b = _helpers.to_tensors(a, b)
214+
result = _impl.dot(a, b)
216215
return _helpers.result_or_out(result, out)
217216

218217

219218
# ### sort and partition ###
220219

221220

222-
def sort(a, axis=-1, kind=None, order=None):
223-
(tensor,) = _helpers.to_tensors(a)
224-
result = _impl.sort(tensor, axis, kind, order)
221+
@normalizer
222+
def sort(a : ArrayLike, axis=-1, kind=None, order=None):
223+
result = _impl.sort(a, axis, kind, order)
225224
return _helpers.array_from(result)
226225

227226

228-
def argsort(a, axis=-1, kind=None, order=None):
229-
(tensor,) = _helpers.to_tensors(a)
230-
result = _impl.argsort(tensor, axis, kind, order)
227+
@normalizer
228+
def argsort(a : ArrayLike, axis=-1, kind=None, order=None):
229+
result = _impl.argsort(a, axis, kind, order)
231230
return _helpers.array_from(result)
232231

233232

234-
def searchsorted(a, v, side="left", sorter=None):
235-
a_t, v_t, sorter_t = _helpers.to_tensors_or_none(a, v, sorter)
236-
result = torch.searchsorted(a_t, v_t, side=side, sorter=sorter_t)
233+
@normalizer
234+
def searchsorted(a : ArrayLike, v : ArrayLike, side="left", sorter : Optional[ArrayLike]=None):
235+
result = torch.searchsorted(a, v, side=side, sorter=sorter)
237236
return _helpers.array_from(result)
238237

239238

240239
# ### swap/move/roll axis ###
241240

242241

243-
def moveaxis(a, source, destination):
244-
(tensor,) = _helpers.to_tensors(a)
245-
result = _impl.moveaxis(tensor, source, destination)
242+
@normalizer
243+
def moveaxis(a : ArrayLike, source, destination):
244+
result = _impl.moveaxis(a, source, destination)
246245
return _helpers.array_from(result)
247246

248247

249-
def swapaxes(a, axis1, axis2):
250-
(tensor,) = _helpers.to_tensors(a)
251-
result = _flips.swapaxes(tensor, axis1, axis2)
248+
@normalizer
249+
def swapaxes(a : ArrayLike, axis1, axis2):
250+
result = _flips.swapaxes(a, axis1, axis2)
252251
return _helpers.array_from(result)
253252

254253

255-
def rollaxis(a, axis, start=0):
256-
(tensor,) = _helpers.to_tensors(a)
254+
@normalizer
255+
def rollaxis(a : ArrayLike, axis, start=0):
257256
result = _flips.rollaxis(a, axis, start)
258257
return _helpers.array_from(result)
259258

260259

261260
# ### shape manipulations ###
262261

263262

264-
def squeeze(a, axis=None):
265-
(tensor,) = _helpers.to_tensors(a)
266-
result = _impl.squeeze(tensor, axis)
263+
@normalizer
264+
def squeeze(a : ArrayLike, axis=None):
265+
result = _impl.squeeze(a, axis)
267266
return _helpers.array_from(result, a)
268267

269268

270-
def reshape(a, newshape, order="C"):
271-
(tensor,) = _helpers.to_tensors(a)
272-
result = _impl.reshape(tensor, newshape, order=order)
269+
@normalizer
270+
def reshape(a : ArrayLike, newshape, order="C"):
271+
result = _impl.reshape(a, newshape, order=order)
273272
return _helpers.array_from(result, a)
274273

275274

276-
def transpose(a, axes=None):
277-
(tensor,) = _helpers.to_tensors(a)
278-
result = _impl.transpose(tensor, axes)
275+
@normalizer
276+
def transpose(a : ArrayLike, axes=None):
277+
result = _impl.transpose(a, axes)
279278
return _helpers.array_from(result, a)
280279

281280

282-
def ravel(a, order="C"):
283-
(tensor,) = _helpers.to_tensors(a)
284-
result = _impl.ravel(tensor)
281+
@normalizer
282+
def ravel(a : ArrayLike, order="C"):
283+
result = _impl.ravel(a)
285284
return _helpers.array_from(result, a)
286285

287286

288287
# leading underscore since arr.flatten exists but np.flatten does not
289-
def _flatten(a, order="C"):
290-
(tensor,) = _helpers.to_tensors(a)
291-
result = _impl._flatten(tensor)
288+
@normalizer
289+
def _flatten(a : ArrayLike, order="C"):
290+
result = _impl._flatten(a)
292291
return _helpers.array_from(result, a)
293292

294293

295294
# ### Type/shape etc queries ###
296295

297296

298-
def real(a):
299-
(tensor,) = _helpers.to_tensors(a)
300-
result = torch.real(tensor)
297+
@normalizer
298+
def real(a : ArrayLike):
299+
result = torch.real(a)
301300
return _helpers.array_from(result)
302301

303302

304-
def imag(a):
305-
(tensor,) = _helpers.to_tensors(a)
306-
result = _impl.imag(tensor)
303+
@normalizer
304+
def imag(a: ArrayLike):
305+
result = _impl.imag(a)
307306
return _helpers.array_from(result)
308307

309308

310-
def round_(a, decimals=0, out=None):
311-
(tensor,) = _helpers.to_tensors(a)
312-
result = _impl.round(tensor, decimals)
309+
@normalizer
310+
def round_(a : ArrayLike, decimals=0, out=None):
311+
result = _impl.round(a, decimals)
313312
return _helpers.result_or_out(result, out)
314313

315314

0 commit comments

Comments
 (0)