@@ -81,6 +81,18 @@ def base(self):
81
81
def T (self ):
82
82
return self .transpose ()
83
83
84
+ @property
85
+ def real (self ):
86
+ return asarray (self ._tensor .real )
87
+
88
+ @property
89
+ def imag (self ):
90
+ try :
91
+ return asarray (self ._tensor .imag )
92
+ except RuntimeError :
93
+ zeros = torch .zeros_like (self ._tensor )
94
+ return ndarray ._from_tensor_and_base (zeros , None )
95
+
84
96
# ctors
85
97
def astype (self , dtype ):
86
98
newt = ndarray ()
@@ -102,6 +114,13 @@ def __str__(self):
102
114
103
115
### comparisons ###
104
116
def __eq__ (self , other ):
117
+ try :
118
+ t_other = asarray (other ).get
119
+ except RuntimeError :
120
+ # Failed to convert other to array: definitely not equal.
121
+ # TODO: generalize, delegate to ufuncs
122
+ falsy = torch .full (self .shape , fill_value = False , dtype = bool )
123
+ return asarray (falsy )
105
124
return asarray (self ._tensor == asarray (other ).get ())
106
125
107
126
def __neq__ (self , other ):
@@ -119,7 +138,6 @@ def __ge__(self, other):
119
138
def __le__ (self , other ):
120
139
return asarray (self ._tensor <= asarray (other ).get ())
121
140
122
-
123
141
def __bool__ (self ):
124
142
try :
125
143
return bool (self ._tensor )
@@ -141,6 +159,9 @@ def __hash__(self):
141
159
def __float__ (self ):
142
160
return float (self ._tensor )
143
161
162
+ def __int__ (self ):
163
+ return int (self ._tensor )
164
+
144
165
# XXX : are single-element ndarrays scalars?
145
166
def is_integer (self ):
146
167
if self .shape == ():
@@ -167,7 +188,10 @@ def __iadd__(self, other):
167
188
168
189
def __sub__ (self , other ):
169
190
other_tensor = asarray (other ).get ()
170
- return asarray (self ._tensor .__sub__ (other_tensor ))
191
+ try :
192
+ return asarray (self ._tensor .__sub__ (other_tensor ))
193
+ except RuntimeError as e :
194
+ raise TypeError (e .args )
171
195
172
196
def __mul__ (self , other ):
173
197
other_tensor = asarray (other ).get ()
@@ -177,10 +201,30 @@ def __rmul__(self, other):
177
201
other_tensor = asarray (other ).get ()
178
202
return asarray (self ._tensor .__rmul__ (other_tensor ))
179
203
204
+ def __floordiv__ (self , other ):
205
+ other_tensor = asarray (other ).get ()
206
+ return asarray (self ._tensor .__floordiv__ (other_tensor ))
207
+
208
+ def __ifloordiv__ (self , other ):
209
+ other_tensor = asarray (other ).get ()
210
+ return asarray (self ._tensor .__ifloordiv__ (other_tensor ))
211
+
180
212
def __truediv__ (self , other ):
181
213
other_tensor = asarray (other ).get ()
182
214
return asarray (self ._tensor .__truediv__ (other_tensor ))
183
215
216
+ def __itruediv__ (self , other ):
217
+ other_tensor = asarray (other ).get ()
218
+ return asarray (self ._tensor .__itruediv__ (other_tensor ))
219
+
220
+ def __mod__ (self , other ):
221
+ other_tensor = asarray (other ).get ()
222
+ return asarray (self ._tensor .__mod__ (other_tensor ))
223
+
224
+ def __imod__ (self , other ):
225
+ other_tensor = asarray (other ).get ()
226
+ return asarray (self ._tensor .__imod__ (other_tensor ))
227
+
184
228
def __or__ (self , other ):
185
229
other_tensor = asarray (other ).get ()
186
230
return asarray (self ._tensor .__or__ (other_tensor ))
@@ -189,10 +233,22 @@ def __ior__(self, other):
189
233
other_tensor = asarray (other ).get ()
190
234
return asarray (self ._tensor .__ior__ (other_tensor ))
191
235
192
-
193
236
def __invert__ (self ):
194
237
return asarray (self ._tensor .__invert__ ())
195
238
239
+ def __abs__ (self ):
240
+ return asarray (self ._tensor .__abs__ ())
241
+
242
+ def __neg__ (self ):
243
+ try :
244
+ return asarray (self ._tensor .__neg__ ())
245
+ except RuntimeError as e :
246
+ raise TypeError (e .args )
247
+
248
+ def __pow__ (self , exponent ):
249
+ exponent_tensor = asarray (exponent ).get ()
250
+ return asarray (self ._tensor .__pow__ (exponent_tensor ))
251
+
196
252
### methods to match namespace functions
197
253
198
254
def squeeze (self , axis = None ):
@@ -301,7 +357,7 @@ def mean(self, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoVal
301
357
302
358
if dtype is None :
303
359
dtype = self .dtype
304
- if not _dtypes .is_floating (dtype ):
360
+ if _dtypes .is_integer (dtype ):
305
361
dtype = _dtypes .default_float_type ()
306
362
torch_dtype = _dtypes .torch_dtype_from (dtype )
307
363
@@ -321,7 +377,7 @@ def sum(self, axis=None, dtype=None, out=None, keepdims=NoValue,
321
377
322
378
if dtype is None :
323
379
dtype = self .dtype
324
- if not _dtypes .is_floating (dtype ):
380
+ if _dtypes .is_integer (dtype ):
325
381
dtype = _dtypes .default_float_type ()
326
382
torch_dtype = _dtypes .torch_dtype_from (dtype )
327
383
@@ -343,67 +399,80 @@ def __setitem__(self, index, value):
343
399
return self ._tensor .__setitem__ (index , value )
344
400
345
401
346
- def asarray (a , dtype = None , order = None , * , like = None ):
347
- _util .subok_not_ok (like )
348
- if order is not None :
402
+ # This is the ideally the only place which talks to ndarray directly.
403
+ # The rest goes through asarray (preferred) or array.
404
+
405
+ def array (object , dtype = None , * , copy = True , order = 'K' , subok = False , ndmin = 0 ,
406
+ like = None ):
407
+ _util .subok_not_ok (like , subok )
408
+ if order != 'K' :
349
409
raise NotImplementedError
350
410
351
- if isinstance ( a , ndarray ):
352
- if dtype is not None and dtype != a . dtype :
353
- a = a . astype ( dtype )
354
- return a
411
+ # a happy path
412
+ if isinstance ( object , ndarray ) :
413
+ if copy is False and dtype is None and ndmin <= object . ndim :
414
+ return object
355
415
356
- if isinstance ( a , ( list , tuple )):
357
- # handle lists of ndarrays, [1, [2, 3], ndarray(4)] etc
416
+ # lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists
417
+ if isinstance ( object , ( list , tuple )):
358
418
a1 = []
359
- for elem in a :
419
+ for elem in object :
360
420
if isinstance (elem , ndarray ):
361
421
a1 .append (elem .get ().tolist ())
362
422
else :
363
423
a1 .append (elem )
424
+ object = a1
425
+
426
+ # get the tensor from "object"
427
+ if isinstance (object , ndarray ):
428
+ tensor = object ._tensor
429
+ base = object
430
+ elif isinstance (object , torch .Tensor ):
431
+ tensor = object
432
+ base = None
364
433
else :
365
- a1 = a
434
+ tensor = torch .as_tensor (object )
435
+ base = None
366
436
367
- torch_dtype = _dtypes .torch_dtype_from (dtype )
437
+ # At this point, `tensor.dtype` is the pytorch default. Our default may
438
+ # differ, so need to typecast. However, we cannot just do `tensor.to`,
439
+ # because if our desired dtype is wider then pytorch's, `tensor`
440
+ # may have lost precision:
368
441
369
- # This and array(...) are the only places which talk to ndarray directly.
370
- # The rest goes through asarray (preferred) or array.
371
- out = ndarray ()
372
- tt = torch .as_tensor (a1 , dtype = torch_dtype )
373
- out ._tensor = tt
374
- return out
442
+ # int(torch.as_tensor(1e12)) - 1e12 equals -4096 (try it!)
375
443
444
+ # Therefore, we treat `tensor.dtype` as a hint, and convert the
445
+ # original object *again*, this time with an explicit dtype.
446
+ dtyp = _dtypes .dtype_from_torch (tensor .dtype )
447
+ default = _dtypes .get_default_dtype_for (dtyp )
448
+ torch_dtype = _dtypes .torch_dtype_from (default )
376
449
377
- def array (object , dtype = None , * , copy = True , order = 'K' , subok = False , ndmin = 0 ,
378
- like = None ):
379
- _util .subok_not_ok (like , subok )
380
- if order != 'K' :
381
- raise NotImplementedError
382
-
383
- if isinstance (object , (list , tuple )):
384
- obj = asarray (object )
385
- return array (obj , dtype , copy = copy , order = order , subok = subok ,
386
- ndmin = ndmin , like = like )
450
+ tensor = torch .as_tensor (object , dtype = torch_dtype )
387
451
388
- if isinstance (object , ndarray ):
389
- result = object ._tensor
390
-
391
- if dtype != object .dtype :
392
- torch_dtype = _dtypes .torch_dtype_from (dtype )
393
- result = result .to (torch_dtype )
394
- else :
452
+ # type cast if requested
453
+ if dtype is not None :
395
454
torch_dtype = _dtypes .torch_dtype_from (dtype )
396
- result = torch .as_tensor (object , dtype = torch_dtype )
455
+ tensor = tensor .to (torch_dtype )
456
+ base = None
397
457
458
+ # adjust ndim if needed
459
+ ndim_extra = ndmin - tensor .ndim
460
+ if ndim_extra > 0 :
461
+ tensor = tensor .view ((1 ,)* ndim_extra + tensor .shape )
462
+ base = None
463
+
464
+ # copy if requested
398
465
if copy :
399
- result = result .clone ()
466
+ tensor = tensor .clone ()
467
+ base = None
400
468
401
- ndim_extra = ndmin - result .ndim
402
- if ndim_extra > 0 :
403
- result = result .reshape ((1 ,)* ndim_extra + result .shape )
404
- out = ndarray ()
405
- out ._tensor = result
406
- return out
469
+ return ndarray ._from_tensor_and_base (tensor , base )
470
+
471
+
472
+ def asarray (a , dtype = None , order = None , * , like = None ):
473
+ if order is None :
474
+ order = 'K'
475
+ return array (a , dtype = dtype , order = order , like = like , copy = False , ndmin = 0 )
407
476
408
477
409
478
0 commit comments