@@ -62,42 +62,36 @@ def __getitem__(self, key):
62
62
63
63
64
64
class ndarray :
65
- def __init__ (self ):
66
- self ._tensor = torch .Tensor ()
67
-
68
- @classmethod
69
- def _from_tensor (cls , tensor ):
70
- self = cls ()
71
- self ._tensor = tensor
72
- return self
73
-
74
- def get (self ):
75
- return self ._tensor
65
+ def __init__ (self , t = None ):
66
+ if t is None :
67
+ self .tensor = torch .Tensor ()
68
+ else :
69
+ self .tensor = torch .as_tensor (t )
76
70
77
71
@property
78
72
def shape (self ):
79
- return tuple (self ._tensor .shape )
73
+ return tuple (self .tensor .shape )
80
74
81
75
@property
82
76
def size (self ):
83
- return self ._tensor .numel ()
77
+ return self .tensor .numel ()
84
78
85
79
@property
86
80
def ndim (self ):
87
- return self ._tensor .ndim
81
+ return self .tensor .ndim
88
82
89
83
@property
90
84
def dtype (self ):
91
- return _dtypes .dtype (self ._tensor .dtype )
85
+ return _dtypes .dtype (self .tensor .dtype )
92
86
93
87
@property
94
88
def strides (self ):
95
- elsize = self ._tensor .element_size ()
96
- return tuple (stride * elsize for stride in self ._tensor .stride ())
89
+ elsize = self .tensor .element_size ()
90
+ return tuple (stride * elsize for stride in self .tensor .stride ())
97
91
98
92
@property
99
93
def itemsize (self ):
100
- return self ._tensor .element_size ()
94
+ return self .tensor .element_size ()
101
95
102
96
@property
103
97
def flags (self ):
@@ -106,15 +100,15 @@ def flags(self):
106
100
# check if F contiguous
107
101
from itertools import accumulate
108
102
109
- f_strides = tuple (accumulate (list (self ._tensor .shape ), func = lambda x , y : x * y ))
103
+ f_strides = tuple (accumulate (list (self .tensor .shape ), func = lambda x , y : x * y ))
110
104
f_strides = (1 ,) + f_strides [:- 1 ]
111
- is_f_contiguous = f_strides == self ._tensor .stride ()
105
+ is_f_contiguous = f_strides == self .tensor .stride ()
112
106
113
107
return Flags (
114
108
{
115
- "C_CONTIGUOUS" : self ._tensor .is_contiguous (),
109
+ "C_CONTIGUOUS" : self .tensor .is_contiguous (),
116
110
"F_CONTIGUOUS" : is_f_contiguous ,
117
- "OWNDATA" : self ._tensor ._base is None ,
111
+ "OWNDATA" : self .tensor ._base is None ,
118
112
"WRITEABLE" : True , # pytorch does not have readonly tensors
119
113
}
120
114
)
@@ -129,38 +123,38 @@ def real(self):
129
123
130
124
@real .setter
131
125
def real (self , value ):
132
- self ._tensor .real = asarray (value ).get ()
126
+ self .tensor .real = asarray (value ).tensor
133
127
134
128
@property
135
129
def imag (self ):
136
130
return _funcs .imag (self )
137
131
138
132
@imag .setter
139
133
def imag (self , value ):
140
- self ._tensor .imag = asarray (value ).get ()
134
+ self .tensor .imag = asarray (value ).tensor
141
135
142
136
round = _funcs .round
143
137
144
138
# ctors
145
139
def astype (self , dtype ):
146
140
newt = ndarray ()
147
141
torch_dtype = _dtypes .dtype (dtype ).torch_dtype
148
- newt ._tensor = self ._tensor .to (torch_dtype )
142
+ newt .tensor = self .tensor .to (torch_dtype )
149
143
return newt
150
144
151
145
def copy (self , order = "C" ):
152
146
if order != "C" :
153
147
raise NotImplementedError
154
- tensor = self ._tensor .clone ()
155
- return ndarray . _from_tensor (tensor )
148
+ tensor = self .tensor .clone ()
149
+ return ndarray (tensor )
156
150
157
151
def tolist (self ):
158
- return self ._tensor .tolist ()
152
+ return self .tensor .tolist ()
159
153
160
154
### niceties ###
161
155
def __str__ (self ):
162
156
return (
163
- str (self ._tensor )
157
+ str (self .tensor )
164
158
.replace ("tensor" , "array_w" )
165
159
.replace ("dtype=torch." , "dtype=" )
166
160
)
@@ -191,7 +185,7 @@ def __ne__(self, other):
191
185
192
186
def __bool__ (self ):
193
187
try :
194
- return bool (self ._tensor )
188
+ return bool (self .tensor )
195
189
except RuntimeError :
196
190
raise ValueError (
197
191
"The truth value of an array with more than one "
@@ -200,35 +194,35 @@ def __bool__(self):
200
194
201
195
def __index__ (self ):
202
196
try :
203
- return operator .index (self ._tensor .item ())
197
+ return operator .index (self .tensor .item ())
204
198
except Exception :
205
199
mesg = "only integer scalar arrays can be converted to a scalar index"
206
200
raise TypeError (mesg )
207
201
208
202
def __float__ (self ):
209
- return float (self ._tensor )
203
+ return float (self .tensor )
210
204
211
205
def __complex__ (self ):
212
206
try :
213
- return complex (self ._tensor )
207
+ return complex (self .tensor )
214
208
except ValueError as e :
215
209
raise TypeError (* e .args )
216
210
217
211
def __int__ (self ):
218
- return int (self ._tensor )
212
+ return int (self .tensor )
219
213
220
214
# XXX : are single-element ndarrays scalars?
221
215
# in numpy, only array scalars have the `is_integer` method
222
216
def is_integer (self ):
223
217
try :
224
- result = int (self ._tensor ) == self ._tensor
218
+ result = int (self .tensor ) == self .tensor
225
219
except Exception :
226
220
result = False
227
221
return result
228
222
229
223
### sequence ###
230
224
def __len__ (self ):
231
- return self ._tensor .shape [0 ]
225
+ return self .tensor .shape [0 ]
232
226
233
227
### arithmetic ###
234
228
@@ -354,8 +348,8 @@ def reshape(self, *shape, order="C"):
354
348
355
349
def sort (self , axis = - 1 , kind = None , order = None ):
356
350
# ndarray.sort works in-place
357
- result = _impl .sort (self ._tensor , axis , kind , order )
358
- self ._tensor = result
351
+ result = _impl .sort (self .tensor , axis , kind , order )
352
+ self .tensor = result
359
353
360
354
argsort = _funcs .argsort
361
355
searchsorted = _funcs .searchsorted
@@ -392,13 +386,13 @@ def _upcast_int_indices(index):
392
386
def __getitem__ (self , index ):
393
387
index = _helpers .ndarrays_to_tensors (index )
394
388
index = ndarray ._upcast_int_indices (index )
395
- return ndarray . _from_tensor (self ._tensor .__getitem__ (index ))
389
+ return ndarray (self .tensor .__getitem__ (index ))
396
390
397
391
def __setitem__ (self , index , value ):
398
392
index = _helpers .ndarrays_to_tensors (index )
399
393
index = ndarray ._upcast_int_indices (index )
400
394
value = _helpers .ndarrays_to_tensors (value )
401
- return self ._tensor .__setitem__ (index , value )
395
+ return self .tensor .__setitem__ (index , value )
402
396
403
397
404
398
# This is the ideally the only place which talks to ndarray directly.
@@ -420,22 +414,22 @@ def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=N
420
414
a1 = []
421
415
for elem in obj :
422
416
if isinstance (elem , ndarray ):
423
- a1 .append (elem .get () .tolist ())
417
+ a1 .append (elem .tensor .tolist ())
424
418
else :
425
419
a1 .append (elem )
426
420
obj = a1
427
421
428
422
# is obj an ndarray already?
429
423
if isinstance (obj , ndarray ):
430
- obj = obj .get ()
424
+ obj = obj .tensor
431
425
432
426
# is a specific dtype requrested?
433
427
torch_dtype = None
434
428
if dtype is not None :
435
429
torch_dtype = _dtypes .dtype (dtype ).torch_dtype
436
430
437
431
tensor = _util ._coerce_to_tensor (obj , torch_dtype , copy , ndmin )
438
- return ndarray . _from_tensor (tensor )
432
+ return ndarray (tensor )
439
433
440
434
441
435
def asarray (a , dtype = None , order = None , * , like = None ):
0 commit comments