@@ -282,12 +282,20 @@ def arange(start=None, stop=None, step=1, dtype=None, *, like=None):
282
282
raise ValueError ("Maximum allowed size exceeded" )
283
283
284
284
285
+ @_decorators .dtype_to_torch
285
286
def empty (shape , dtype = float , order = "C" , * , like = None ):
286
287
_util .subok_not_ok (like )
287
288
if order != "C" :
288
289
raise NotImplementedError
289
- torch_dtype = _dtypes .torch_dtype_from (dtype )
290
- return asarray (torch .empty (shape , dtype = torch_dtype ))
290
+
291
+ if dtype is None :
292
+ from ._detail ._scalar_types import default_float_type
293
+
294
+ dtype = default_float_type .torch_dtype
295
+
296
+ result = torch .empty (shape , dtype = dtype )
297
+
298
+ return asarray (result )
291
299
292
300
293
301
# NB: *_like function deliberately deviate from numpy: it has subok=True
@@ -303,15 +311,22 @@ def empty_like(prototype, dtype=None, order="K", subok=False, shape=None):
303
311
result = result .reshape (shape )
304
312
return result
305
313
314
+
306
315
@_decorators .dtype_to_torch
307
316
def full (shape , fill_value , dtype = None , order = "C" , * , like = None ):
308
317
_util .subok_not_ok (like )
309
318
if order != "C" :
310
319
raise NotImplementedError
320
+
311
321
fill_value = asarray (fill_value ).get ()
312
322
if dtype is None :
313
- dtype = fill_value .dtype
323
+ dtype = fill_value .dtype
324
+
325
+ if not isinstance (shape , (tuple , list )):
326
+ shape = (shape ,)
327
+
314
328
result = torch .full (shape , fill_value , dtype = dtype )
329
+
315
330
return asarray (result )
316
331
317
332
@@ -327,12 +342,19 @@ def full_like(a, fill_value, dtype=None, order="K", subok=False, shape=None):
327
342
return result
328
343
329
344
345
+ @_decorators .dtype_to_torch
330
346
def ones (shape , dtype = None , order = "C" , * , like = None ):
331
347
_util .subok_not_ok (like )
332
348
if order != "C" :
333
349
raise NotImplementedError
334
- torch_dtype = _dtypes .torch_dtype_from (dtype )
335
- return asarray (torch .ones (shape , dtype = torch_dtype ))
350
+ if dtype is None :
351
+ from ._detail ._scalar_types import default_float_type
352
+
353
+ dtype = default_float_type .torch_dtype
354
+
355
+ result = torch .ones (shape , dtype = dtype )
356
+
357
+ return asarray (result )
336
358
337
359
338
360
@asarray_replacer ()
@@ -354,7 +376,8 @@ def zeros(shape, dtype=None, order="C", *, like=None):
354
376
raise NotImplementedError
355
377
if dtype is None :
356
378
dtype = _dtypes_impl .default_float_dtype
357
- return asarray (torch .zeros (shape , dtype = dtype ))
379
+ result = torch .zeros (shape , dtype = dtype )
380
+ return asarray (result )
358
381
359
382
360
383
@asarray_replacer ()
0 commit comments