@@ -159,6 +159,9 @@ def __hash__(self):
159
159
def __float__(self):
    """Coerce the wrapped tensor to a Python float (delegates to ``float(tensor)``)."""
    tensor = self._tensor
    return float(tensor)
162
def __int__(self):
    """Coerce the wrapped tensor to a Python int (delegates to ``int(tensor)``)."""
    tensor = self._tensor
    return int(tensor)
162
165
# XXX : are single-element ndarrays scalars?
163
166
def is_integer (self ):
164
167
if self .shape == ():
@@ -354,7 +357,7 @@ def mean(self, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoVal
354
357
355
358
if dtype is None :
356
359
dtype = self .dtype
357
- if not _dtypes .is_floating (dtype ):
360
+ if _dtypes .is_integer (dtype ):
358
361
dtype = _dtypes .default_float_type ()
359
362
torch_dtype = _dtypes .torch_dtype_from (dtype )
360
363
@@ -374,7 +377,7 @@ def sum(self, axis=None, dtype=None, out=None, keepdims=NoValue,
374
377
375
378
if dtype is None :
376
379
dtype = self .dtype
377
- if not _dtypes .is_floating (dtype ):
380
+ if _dtypes .is_integer (dtype ):
378
381
dtype = _dtypes .default_float_type ()
379
382
torch_dtype = _dtypes .torch_dtype_from (dtype )
380
383
@@ -396,67 +399,80 @@ def __setitem__(self, index, value):
396
399
return self ._tensor .__setitem__ (index , value )
397
400
398
401
399
# This is ideally the only place which talks to ndarray directly.
# The rest goes through asarray (preferred) or array.

def array(object, dtype=None, *, copy=True, order='K', subok=False, ndmin=0,
          like=None):
    """Create an ndarray wrapping a torch.Tensor (numpy.array work-alike).

    `object` may be an ndarray, a torch.Tensor, or (nested) lists/tuples
    possibly containing ndarrays.  `subok` and `like` are rejected via
    `_util.subok_not_ok`; only order='K' is implemented (anything else
    raises NotImplementedError).

    Returns an ndarray; when the result shares memory with an input
    ndarray, that input is recorded as the result's base.
    """
    _util.subok_not_ok(like, subok)
    if order != 'K':
        raise NotImplementedError

    # a happy path: nothing to convert, cast, or pad — return the input itself
    if isinstance(object, ndarray):
        if copy is False and dtype is None and ndmin <= object.ndim:
            return object

    # lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists
    if isinstance(object, (list, tuple)):
        a1 = []
        for elem in object:
            if isinstance(elem, ndarray):
                a1.append(elem.get().tolist())
            else:
                a1.append(elem)
        object = a1

    # get the tensor from "object"
    if isinstance(object, ndarray):
        tensor = object._tensor
        base = object        # result may view the input's memory
    elif isinstance(object, torch.Tensor):
        tensor = object
        base = None
    else:
        tensor = torch.as_tensor(object)
        base = None

        # At this point, `tensor.dtype` is the pytorch default. Our default may
        # differ, so need to typecast. However, we cannot just do `tensor.to`,
        # because if our desired dtype is wider than pytorch's, `tensor`
        # may have lost precision:

        # int(torch.as_tensor(1e12)) - 1e12 equals -4096 (try it!)

        # Therefore, we treat `tensor.dtype` as a hint, and convert the
        # original object *again*, this time with an explicit dtype.
        dtyp = _dtypes.dtype_from_torch(tensor.dtype)
        default = _dtypes.get_default_dtype_for(dtyp)
        torch_dtype = _dtypes.torch_dtype_from(default)

        tensor = torch.as_tensor(object, dtype=torch_dtype)

    # type cast if requested
    if dtype is not None:
        torch_dtype = _dtypes.torch_dtype_from(dtype)
        tensor = tensor.to(torch_dtype)
        base = None          # the cast produced a new tensor; drop the base link

    # adjust ndim if needed
    ndim_extra = ndmin - tensor.ndim
    if ndim_extra > 0:
        tensor = tensor.view((1,)*ndim_extra + tensor.shape)
        base = None

    # copy if requested
    if copy:
        tensor = tensor.clone()
        base = None          # a clone never aliases the input

    return ndarray._from_tensor_and_base(tensor, base)
470
+
471
+
472
def asarray(a, dtype=None, order=None, *, like=None):
    """Convert `a` to an ndarray without copying when possible.

    Thin wrapper over `array(...)` with copy=False and ndmin=0; a `None`
    order is normalized to the only supported value, 'K'.
    """
    effective_order = 'K' if order is None else order
    return array(a, dtype=dtype, order=effective_order,
                 like=like, copy=False, ndmin=0)
460
476
461
477
462
478
0 commit comments