@@ -151,7 +151,7 @@ def dstack(tup, *, dtype=None, casting="same_kind"):
151
151
# but {h,v}stack do. Hence add them here for consistency.
152
152
tensors = _helpers .to_tensors (* tup )
153
153
result = _impl .dstack (tensors , dtype = dtype , casting = casting )
154
- return asarray (result )
154
+ return asarray (result )
155
155
156
156
157
157
@_decorators .dtype_to_torch
@@ -257,49 +257,27 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
257
257
return asarray (torch .linspace (start , stop , num , dtype = dtype ))
258
258
259
259
260
+ @_decorators .dtype_to_torch
260
261
def geomspace (start , stop , num = 50 , endpoint = True , dtype = None , axis = 0 ):
261
262
if axis != 0 or not endpoint :
262
263
raise NotImplementedError
263
- tstart , tstop = torch .as_tensor ([start , stop ])
264
- base = torch .pow (tstop / tstart , 1.0 / (num - 1 ))
265
- result = torch .logspace (
266
- torch .log (tstart ) / torch .log (base ),
267
- torch .log (tstop ) / torch .log (base ),
268
- num ,
269
- base = base ,
270
- )
264
+ result = _impl .geomspace (start , stop , num , endpoint , dtype , axis )
271
265
return asarray (result )
272
266
273
267
268
+ @_decorators .dtype_to_torch
274
269
def logspace (start , stop , num = 50 , endpoint = True , base = 10.0 , dtype = None , axis = 0 ):
275
270
if axis != 0 or not endpoint :
276
271
raise NotImplementedError
277
272
return asarray (torch .logspace (start , stop , num , base = base , dtype = dtype ))
278
273
279
274
275
+ @_decorators .dtype_to_torch
280
276
def arange (start = None , stop = None , step = 1 , dtype = None , * , like = None ):
281
277
_util .subok_not_ok (like )
282
- if step == 0 :
283
- raise ZeroDivisionError
284
- if stop is None and start is None :
285
- raise TypeError
286
- if stop is None :
287
- # XXX: this breaks if start is passed as a kwarg:
288
- # arange(start=4) should raise (no stop) but doesn't
289
- start , stop = 0 , start
290
- if start is None :
291
- start = 0
292
-
293
- if dtype is None :
294
- dtype = _dtypes .default_int_type ()
295
- dtype = result_type (start , stop , step , dtype )
296
- torch_dtype = _dtypes .torch_dtype_from (dtype )
297
278
start , stop , step = _helpers .ndarrays_to_tensors (start , stop , step )
298
-
299
- try :
300
- return asarray (torch .arange (start , stop , step , dtype = torch_dtype ))
301
- except RuntimeError :
302
- raise ValueError ("Maximum allowed size exceeded" )
279
+ result = _impl .arange (start , stop , step , dtype = dtype )
280
+ return asarray (result )
303
281
304
282
305
283
@_decorators .dtype_to_torch
@@ -316,14 +294,12 @@ def empty(shape, dtype=float, order="C", *, like=None):
316
294
# NB: *_like function deliberately deviate from numpy: it has subok=True
317
295
# as the default; we set subok=False and raise on anything else.
318
296
@asarray_replacer ()
297
+ @_decorators .dtype_to_torch
319
298
def empty_like (prototype , dtype = None , order = "K" , subok = False , shape = None ):
320
299
_util .subok_not_ok (subok = subok )
321
300
if order != "K" :
322
301
raise NotImplementedError
323
- torch_dtype = None if dtype is None else _dtypes .torch_dtype_from (dtype )
324
- result = torch .empty_like (prototype , dtype = torch_dtype )
325
- if shape is not None :
326
- result = result .reshape (shape )
302
+ result = _impl .empty_like (prototype , dtype = dtype , shape = shape )
327
303
return result
328
304
329
305
@@ -332,28 +308,18 @@ def full(shape, fill_value, dtype=None, order="C", *, like=None):
332
308
_util .subok_not_ok (like )
333
309
if order != "C" :
334
310
raise NotImplementedError
335
-
336
311
fill_value = asarray (fill_value ).get ()
337
- if dtype is None :
338
- dtype = fill_value .dtype
339
-
340
- if not isinstance (shape , (tuple , list )):
341
- shape = (shape ,)
342
-
343
- result = torch .full (shape , fill_value , dtype = dtype )
344
-
312
+ result = _impl .full (shape , fill_value , dtype = dtype )
345
313
return asarray (result )
346
314
347
315
348
316
@asarray_replacer ()
317
+ @_decorators .dtype_to_torch
349
318
def full_like (a , fill_value , dtype = None , order = "K" , subok = False , shape = None ):
350
319
_util .subok_not_ok (subok = subok )
351
320
if order != "K" :
352
321
raise NotImplementedError
353
- torch_dtype = None if dtype is None else _dtypes .torch_dtype_from (dtype )
354
- result = torch .full_like (a , fill_value , dtype = torch_dtype )
355
- if shape is not None :
356
- result = result .reshape (shape )
322
+ result = _impl .full_like (a , fill_value , dtype = dtype , shape = shape )
357
323
return result
358
324
359
325
@@ -369,14 +335,12 @@ def ones(shape, dtype=None, order="C", *, like=None):
369
335
370
336
371
337
@asarray_replacer ()
338
+ @_decorators .dtype_to_torch
372
339
def ones_like (a , dtype = None , order = "K" , subok = False , shape = None ):
373
340
_util .subok_not_ok (subok = subok )
374
341
if order != "K" :
375
342
raise NotImplementedError
376
- torch_dtype = None if dtype is None else _dtypes .torch_dtype_from (dtype )
377
- result = torch .ones_like (a , dtype = torch_dtype )
378
- if shape is not None :
379
- result = result .reshape (shape )
343
+ result = _impl .ones_like (a , dtype = dtype , shape = shape )
380
344
return result
381
345
382
346
@@ -392,14 +356,12 @@ def zeros(shape, dtype=None, order="C", *, like=None):
392
356
393
357
394
358
@asarray_replacer ()
359
+ @_decorators .dtype_to_torch
395
360
def zeros_like (a , dtype = None , order = "K" , subok = False , shape = None ):
396
361
_util .subok_not_ok (subok = subok )
397
362
if order != "K" :
398
363
raise NotImplementedError
399
- torch_dtype = None if dtype is None else _dtypes .torch_dtype_from (dtype )
400
- result = torch .zeros_like (a , dtype = torch_dtype )
401
- if shape is not None :
402
- result = result .reshape (shape )
364
+ result = _impl .zeros_like (a , dtype = dtype , shape = shape )
403
365
return result
404
366
405
367
@@ -408,11 +370,8 @@ def eye(N, M=None, k=0, dtype=float, order="C", *, like=None):
408
370
_util .subok_not_ok (like )
409
371
if order != "C" :
410
372
raise NotImplementedError
411
- if M is None :
412
- M = N
413
- z = torch .zeros (N , M , dtype = dtype )
414
- z .diagonal (k ).fill_ (1 )
415
- return asarray (z )
373
+ result = _impl .eye (N , M , k , dtype )
374
+ return asarray (result )
416
375
417
376
418
377
def identity (n , dtype = None , * , like = None ):
0 commit comments