@@ -62,48 +62,36 @@ def __getitem__(self, key):
62
62
63
63
64
64
class ndarray :
65
- def __init__ (self ):
66
- self ._tensor = torch .Tensor ()
67
- self ._base = None
68
-
69
- @classmethod
70
- def _from_tensor_and_base (cls , tensor , base ):
71
- self = cls ()
72
- self ._tensor = tensor
73
- self ._base = base
74
- return self
75
-
76
- def get (self ):
77
- return self ._tensor
65
+ def __init__ (self , t = None ):
66
+ if t is None :
67
+ self .tensor = torch .Tensor ()
68
+ else :
69
+ self .tensor = torch .as_tensor (t )
78
70
79
71
@property
80
72
def shape (self ):
81
- return tuple (self ._tensor .shape )
73
+ return tuple (self .tensor .shape )
82
74
83
75
@property
84
76
def size (self ):
85
- return self ._tensor .numel ()
77
+ return self .tensor .numel ()
86
78
87
79
@property
88
80
def ndim (self ):
89
- return self ._tensor .ndim
81
+ return self .tensor .ndim
90
82
91
83
@property
92
84
def dtype (self ):
93
- return _dtypes .dtype (self ._tensor .dtype )
85
+ return _dtypes .dtype (self .tensor .dtype )
94
86
95
87
@property
96
88
def strides (self ):
97
- elsize = self ._tensor .element_size ()
98
- return tuple (stride * elsize for stride in self ._tensor .stride ())
89
+ elsize = self .tensor .element_size ()
90
+ return tuple (stride * elsize for stride in self .tensor .stride ())
99
91
100
92
@property
101
93
def itemsize (self ):
102
- return self ._tensor .element_size ()
103
-
104
- @property
105
- def base (self ):
106
- return self ._base
94
+ return self .tensor .element_size ()
107
95
108
96
@property
109
97
def flags (self ):
@@ -112,15 +100,15 @@ def flags(self):
112
100
# check if F contiguous
113
101
from itertools import accumulate
114
102
115
- f_strides = tuple (accumulate (list (self ._tensor .shape ), func = lambda x , y : x * y ))
103
+ f_strides = tuple (accumulate (list (self .tensor .shape ), func = lambda x , y : x * y ))
116
104
f_strides = (1 ,) + f_strides [:- 1 ]
117
- is_f_contiguous = f_strides == self ._tensor .stride ()
105
+ is_f_contiguous = f_strides == self .tensor .stride ()
118
106
119
107
return Flags (
120
108
{
121
- "C_CONTIGUOUS" : self ._tensor .is_contiguous (),
109
+ "C_CONTIGUOUS" : self .tensor .is_contiguous (),
122
110
"F_CONTIGUOUS" : is_f_contiguous ,
123
- "OWNDATA" : self ._tensor ._base is None ,
111
+ "OWNDATA" : self .tensor ._base is None ,
124
112
"WRITEABLE" : True , # pytorch does not have readonly tensors
125
113
}
126
114
)
@@ -135,38 +123,38 @@ def real(self):
135
123
136
124
@real .setter
137
125
def real (self , value ):
138
- self ._tensor .real = asarray (value ).get ()
126
+ self .tensor .real = asarray (value ).tensor
139
127
140
128
@property
141
129
def imag (self ):
142
130
return _funcs .imag (self )
143
131
144
132
@imag .setter
145
133
def imag (self , value ):
146
- self ._tensor .imag = asarray (value ).get ()
134
+ self .tensor .imag = asarray (value ).tensor
147
135
148
136
round = _funcs .round
149
137
150
138
# ctors
151
139
def astype (self , dtype ):
152
140
newt = ndarray ()
153
141
torch_dtype = _dtypes .dtype (dtype ).torch_dtype
154
- newt ._tensor = self ._tensor .to (torch_dtype )
142
+ newt .tensor = self .tensor .to (torch_dtype )
155
143
return newt
156
144
157
145
def copy (self , order = "C" ):
158
146
if order != "C" :
159
147
raise NotImplementedError
160
- tensor = self ._tensor .clone ()
161
- return ndarray . _from_tensor_and_base (tensor , None )
148
+ tensor = self .tensor .clone ()
149
+ return ndarray (tensor )
162
150
163
151
def tolist (self ):
164
- return self ._tensor .tolist ()
152
+ return self .tensor .tolist ()
165
153
166
154
### niceties ###
167
155
def __str__ (self ):
168
156
return (
169
- str (self ._tensor )
157
+ str (self .tensor )
170
158
.replace ("tensor" , "array_w" )
171
159
.replace ("dtype=torch." , "dtype=" )
172
160
)
@@ -197,7 +185,7 @@ def __ne__(self, other):
197
185
198
186
def __bool__ (self ):
199
187
try :
200
- return bool (self ._tensor )
188
+ return bool (self .tensor )
201
189
except RuntimeError :
202
190
raise ValueError (
203
191
"The truth value of an array with more than one "
@@ -206,35 +194,35 @@ def __bool__(self):
206
194
207
195
def __index__ (self ):
208
196
try :
209
- return operator .index (self ._tensor .item ())
197
+ return operator .index (self .tensor .item ())
210
198
except Exception :
211
199
mesg = "only integer scalar arrays can be converted to a scalar index"
212
200
raise TypeError (mesg )
213
201
214
202
def __float__ (self ):
215
- return float (self ._tensor )
203
+ return float (self .tensor )
216
204
217
205
def __complex__ (self ):
218
206
try :
219
- return complex (self ._tensor )
207
+ return complex (self .tensor )
220
208
except ValueError as e :
221
209
raise TypeError (* e .args )
222
210
223
211
def __int__ (self ):
224
- return int (self ._tensor )
212
+ return int (self .tensor )
225
213
226
214
# XXX : are single-element ndarrays scalars?
227
215
# in numpy, only array scalars have the `is_integer` method
228
216
def is_integer (self ):
229
217
try :
230
- result = int (self ._tensor ) == self ._tensor
218
+ result = int (self .tensor ) == self .tensor
231
219
except Exception :
232
220
result = False
233
221
return result
234
222
235
223
### sequence ###
236
224
def __len__ (self ):
237
- return self ._tensor .shape [0 ]
225
+ return self .tensor .shape [0 ]
238
226
239
227
### arithmetic ###
240
228
@@ -360,8 +348,8 @@ def reshape(self, *shape, order="C"):
360
348
361
349
def sort (self , axis = - 1 , kind = None , order = None ):
362
350
# ndarray.sort works in-place
363
- result = _impl .sort (self ._tensor , axis , kind , order )
364
- self ._tensor = result
351
+ result = _impl .sort (self .tensor , axis , kind , order )
352
+ self .tensor = result
365
353
366
354
argsort = _funcs .argsort
367
355
searchsorted = _funcs .searchsorted
@@ -398,13 +386,13 @@ def _upcast_int_indices(index):
398
386
def __getitem__ (self , index ):
399
387
index = _helpers .ndarrays_to_tensors (index )
400
388
index = ndarray ._upcast_int_indices (index )
401
- return ndarray . _from_tensor_and_base (self ._tensor .__getitem__ (index ), self )
389
+ return ndarray (self .tensor .__getitem__ (index ))
402
390
403
391
def __setitem__ (self , index , value ):
404
392
index = _helpers .ndarrays_to_tensors (index )
405
393
index = ndarray ._upcast_int_indices (index )
406
394
value = _helpers .ndarrays_to_tensors (value )
407
- return self ._tensor .__setitem__ (index , value )
395
+ return self .tensor .__setitem__ (index , value )
408
396
409
397
410
398
# This is the ideally the only place which talks to ndarray directly.
@@ -426,25 +414,22 @@ def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=N
426
414
a1 = []
427
415
for elem in obj :
428
416
if isinstance (elem , ndarray ):
429
- a1 .append (elem .get () .tolist ())
417
+ a1 .append (elem .tensor .tolist ())
430
418
else :
431
419
a1 .append (elem )
432
420
obj = a1
433
421
434
422
# is obj an ndarray already?
435
- base = None
436
423
if isinstance (obj , ndarray ):
437
- obj = obj ._tensor
438
- base = obj
424
+ obj = obj .tensor
439
425
440
426
# is a specific dtype requrested?
441
427
torch_dtype = None
442
428
if dtype is not None :
443
429
torch_dtype = _dtypes .dtype (dtype ).torch_dtype
444
- base = None
445
430
446
431
tensor = _util ._coerce_to_tensor (obj , torch_dtype , copy , ndmin )
447
- return ndarray . _from_tensor_and_base (tensor , base )
432
+ return ndarray (tensor )
448
433
449
434
450
435
def asarray (a , dtype = None , order = None , * , like = None ):
@@ -453,10 +438,6 @@ def asarray(a, dtype=None, order=None, *, like=None):
453
438
return array (a , dtype = dtype , order = order , like = like , copy = False , ndmin = 0 )
454
439
455
440
456
- def maybe_set_base (tensor , base ):
457
- return ndarray ._from_tensor_and_base (tensor , base )
458
-
459
-
460
441
###### dtype routines
461
442
462
443
0 commit comments