Commit a8adec8

MAINT: remove ndarray.get(), add a public ndarray.tensor attribute
1 parent 9cbbac6 commit a8adec8
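
A minimal before/after sketch of the API change in this commit (the import name `torch_np` is assumed from the file paths below; `asarray` and the `.tensor` attribute are the pieces this diff touches):

    import torch_np as np

    a = np.asarray([1.0, 2.0, 3.0])

    t = a.tensor    # new: public attribute exposing the wrapped torch.Tensor
    # t = a.get()   # old: accessor method removed by this commit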

7 files changed: +60 −66 lines

torch_np/_decorators.py
Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@ def out_shape_dtype(func):
     @functools.wraps(func)
     def wrapped(*args, out=None, **kwds):
         if out is not None:
-            kwds.update({"out_shape_dtype": (out.get().dtype, out.get().shape)})
+            kwds.update({"out_shape_dtype": (out.tensor.dtype, out.tensor.shape)})
         result_tensor = func(*args, **kwds)
         return _helpers.result_or_out(result_tensor, out)

torch_np/_dtypes.py
Lines changed: 3 additions & 3 deletions

@@ -31,7 +31,7 @@ def __new__(self, value):
             value = {"inf": torch.inf, "nan": torch.nan}[value]

         if isinstance(value, _ndarray.ndarray):
-            tensor = value.get()
+            tensor = value.tensor
         else:
             try:
                 tensor = torch.as_tensor(value, dtype=self.torch_dtype)
@@ -49,7 +49,7 @@ def __new__(self, value):
         # and here we follow the second approach and create a new object
         # *for all inputs*.
         #
-        return _ndarray.ndarray._from_tensor(tensor)
+        return _ndarray.ndarray(tensor)


 ##### these are abstract types
@@ -317,7 +317,7 @@ def __repr__(self):
     @property
     def itemsize(self):
         elem = self.type(1)
-        return elem.get().element_size()
+        return elem.tensor.element_size()

     def __getstate__(self):
         return self._scalar_type

torch_np/_helpers.py
Lines changed: 5 additions & 5 deletions

@@ -49,7 +49,7 @@ def ufunc_preprocess(

     out_shape_dtype = None
     if out is not None:
-        out_shape_dtype = (out.get().dtype, out.get().shape)
+        out_shape_dtype = (out.tensor.dtype, out.tensor.shape)

     tensors = _util.cast_and_broadcast(tensors, out_shape_dtype, casting)

@@ -77,7 +77,7 @@ def result_or_out(result_tensor, out_array=None, promote_scalar=False):
                 f"Bad size of the out array: out.shape = {out_array.shape}"
                 f" while result.shape = {result_tensor.shape}."
             )
-        out_tensor = out_array.get()
+        out_tensor = out_array.tensor
         out_tensor.copy_(result_tensor)
         return out_array
     else:
@@ -87,7 +87,7 @@ def result_or_out(result_tensor, out_array=None, promote_scalar=False):
 def array_from(tensor, base=None):
     from ._ndarray import ndarray

-    return ndarray._from_tensor(tensor)
+    return ndarray(tensor)


 def tuple_arrays_from(result):
@@ -108,7 +108,7 @@ def ndarrays_to_tensors(*inputs):
     elif len(inputs) == 1:
         input_ = inputs[0]
         if isinstance(input_, ndarray):
-            return input_.get()
+            return input_.tensor
         elif isinstance(input_, tuple):
             result = []
             for sub_input in input_:
@@ -126,4 +126,4 @@ def to_tensors(*inputs):
     """Convert all array_likes from `inputs` to tensors."""
     from ._ndarray import asarray, ndarray

-    return tuple(asarray(value).get() for value in inputs)
+    return tuple(asarray(value).tensor for value in inputs)
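
A hedged sketch of the `out=` pattern after this change: `result_or_out` now copies a computed torch.Tensor into the out array through its public `.tensor` attribute (the `torch_np` import name and the setup values are illustrative assumptions, not part of this diff):

    import torch
    import torch_np as np

    out = np.asarray([0.0, 0.0, 0.0])              # pre-allocated output wrapper
    result_tensor = torch.tensor([1.0, 2.0, 3.0])  # raw result from some torch op

    out.tensor.copy_(result_tensor)   # in-place write through the wrapped tensor
    # `out` is still the same ndarray object, now holding [1., 2., 3.]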

torch_np/_ndarray.py
Lines changed: 37 additions & 43 deletions

@@ -62,42 +62,36 @@ def __getitem__(self, key):


 class ndarray:
-    def __init__(self):
-        self._tensor = torch.Tensor()
-
-    @classmethod
-    def _from_tensor(cls, tensor):
-        self = cls()
-        self._tensor = tensor
-        return self
-
-    def get(self):
-        return self._tensor
+    def __init__(self, t=None):
+        if t is None:
+            self.tensor = torch.Tensor()
+        else:
+            self.tensor = torch.as_tensor(t)

     @property
     def shape(self):
-        return tuple(self._tensor.shape)
+        return tuple(self.tensor.shape)

     @property
     def size(self):
-        return self._tensor.numel()
+        return self.tensor.numel()

     @property
     def ndim(self):
-        return self._tensor.ndim
+        return self.tensor.ndim

     @property
     def dtype(self):
-        return _dtypes.dtype(self._tensor.dtype)
+        return _dtypes.dtype(self.tensor.dtype)

     @property
     def strides(self):
-        elsize = self._tensor.element_size()
-        return tuple(stride * elsize for stride in self._tensor.stride())
+        elsize = self.tensor.element_size()
+        return tuple(stride * elsize for stride in self.tensor.stride())

     @property
     def itemsize(self):
-        return self._tensor.element_size()
+        return self.tensor.element_size()

     @property
     def flags(self):
@@ -106,15 +100,15 @@ def flags(self):
         # check if F contiguous
         from itertools import accumulate

-        f_strides = tuple(accumulate(list(self._tensor.shape), func=lambda x, y: x * y))
+        f_strides = tuple(accumulate(list(self.tensor.shape), func=lambda x, y: x * y))
         f_strides = (1,) + f_strides[:-1]
-        is_f_contiguous = f_strides == self._tensor.stride()
+        is_f_contiguous = f_strides == self.tensor.stride()

         return Flags(
             {
-                "C_CONTIGUOUS": self._tensor.is_contiguous(),
+                "C_CONTIGUOUS": self.tensor.is_contiguous(),
                 "F_CONTIGUOUS": is_f_contiguous,
-                "OWNDATA": self._tensor._base is None,
+                "OWNDATA": self.tensor._base is None,
                 "WRITEABLE": True,  # pytorch does not have readonly tensors
             }
         )
@@ -129,38 +123,38 @@ def real(self):

     @real.setter
     def real(self, value):
-        self._tensor.real = asarray(value).get()
+        self.tensor.real = asarray(value).tensor

     @property
     def imag(self):
         return _funcs.imag(self)

     @imag.setter
     def imag(self, value):
-        self._tensor.imag = asarray(value).get()
+        self.tensor.imag = asarray(value).tensor

     round = _funcs.round

     # ctors
     def astype(self, dtype):
         newt = ndarray()
         torch_dtype = _dtypes.dtype(dtype).torch_dtype
-        newt._tensor = self._tensor.to(torch_dtype)
+        newt.tensor = self.tensor.to(torch_dtype)
         return newt

     def copy(self, order="C"):
         if order != "C":
             raise NotImplementedError
-        tensor = self._tensor.clone()
-        return ndarray._from_tensor(tensor)
+        tensor = self.tensor.clone()
+        return ndarray(tensor)

     def tolist(self):
-        return self._tensor.tolist()
+        return self.tensor.tolist()

     ### niceties ###
     def __str__(self):
         return (
-            str(self._tensor)
+            str(self.tensor)
             .replace("tensor", "array_w")
             .replace("dtype=torch.", "dtype=")
         )
@@ -191,7 +185,7 @@ def __ne__(self, other):

     def __bool__(self):
         try:
-            return bool(self._tensor)
+            return bool(self.tensor)
         except RuntimeError:
             raise ValueError(
                 "The truth value of an array with more than one "
@@ -200,35 +194,35 @@

     def __index__(self):
         try:
-            return operator.index(self._tensor.item())
+            return operator.index(self.tensor.item())
         except Exception:
             mesg = "only integer scalar arrays can be converted to a scalar index"
             raise TypeError(mesg)

     def __float__(self):
-        return float(self._tensor)
+        return float(self.tensor)

     def __complex__(self):
         try:
-            return complex(self._tensor)
+            return complex(self.tensor)
         except ValueError as e:
             raise TypeError(*e.args)

     def __int__(self):
-        return int(self._tensor)
+        return int(self.tensor)

     # XXX : are single-element ndarrays scalars?
     # in numpy, only array scalars have the `is_integer` method
     def is_integer(self):
         try:
-            result = int(self._tensor) == self._tensor
+            result = int(self.tensor) == self.tensor
         except Exception:
             result = False
         return result

     ### sequence ###
     def __len__(self):
-        return self._tensor.shape[0]
+        return self.tensor.shape[0]

     ### arithmetic ###

@@ -354,8 +348,8 @@ def reshape(self, *shape, order="C"):

     def sort(self, axis=-1, kind=None, order=None):
         # ndarray.sort works in-place
-        result = _impl.sort(self._tensor, axis, kind, order)
-        self._tensor = result
+        result = _impl.sort(self.tensor, axis, kind, order)
+        self.tensor = result

     argsort = _funcs.argsort
     searchsorted = _funcs.searchsorted
@@ -392,13 +386,13 @@ def _upcast_int_indices(index):
     def __getitem__(self, index):
         index = _helpers.ndarrays_to_tensors(index)
         index = ndarray._upcast_int_indices(index)
-        return ndarray._from_tensor(self._tensor.__getitem__(index))
+        return ndarray(self.tensor.__getitem__(index))

     def __setitem__(self, index, value):
         index = _helpers.ndarrays_to_tensors(index)
         index = ndarray._upcast_int_indices(index)
         value = _helpers.ndarrays_to_tensors(value)
-        return self._tensor.__setitem__(index, value)
+        return self.tensor.__setitem__(index, value)


 # This is the ideally the only place which talks to ndarray directly.
@@ -420,22 +414,22 @@ def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=N
         a1 = []
         for elem in obj:
             if isinstance(elem, ndarray):
-                a1.append(elem.get().tolist())
+                a1.append(elem.tensor.tolist())
             else:
                 a1.append(elem)
         obj = a1

     # is obj an ndarray already?
     if isinstance(obj, ndarray):
-        obj = obj.get()
+        obj = obj.tensor

     # is a specific dtype requrested?
     torch_dtype = None
     if dtype is not None:
         torch_dtype = _dtypes.dtype(dtype).torch_dtype

     tensor = _util._coerce_to_tensor(obj, torch_dtype, copy, ndmin)
-    return ndarray._from_tensor(tensor)
+    return ndarray(tensor)


 def asarray(a, dtype=None, order=None, *, like=None):
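
A hedged sketch of the reworked constructor: `ndarray()` now optionally wraps an array-like or tensor via `torch.as_tensor`, replacing the removed `_from_tensor()` classmethod (importing from the private `_ndarray` module is an assumption made for illustration only):

    import torch
    from torch_np._ndarray import ndarray

    t = torch.arange(6).reshape(2, 3)
    a = ndarray(t)             # torch.as_tensor returns t unchanged, so no copy is made
    assert a.tensor is t       # .tensor is the public handle to the wrapped tensor
    assert a.shape == (2, 3)   # properties now read from self.tensor

    empty = ndarray()          # no argument still yields an empty torch.Tensor()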

torch_np/tests/numpy_tests/core/test_indexing.py
Lines changed: 2 additions & 2 deletions

@@ -115,7 +115,7 @@ def test_empty_tuple_index(self):
         # Empty tuple index creates a view
         a = np.array([1, 2, 3])
         assert_equal(a[()], a)
-        assert_(a[()].get()._base is a.get())
+        assert_(a[()].tensor._base is a.tensor)
         a = np.array(0)
         pytest.skip(
             "torch doesn't have scalar types with distinct instancing behaviours"
@@ -164,7 +164,7 @@ def test_ellipsis_index(self):
         assert_(a[...] is not a)
         assert_equal(a[...], a)
         # `a[...]` was `a` in numpy <1.9.
-        assert_(a[...].get()._base is a.get())
+        assert_(a[...].tensor._base is a.tensor)

         # Slicing with ellipsis can skip an
         # arbitrary number of dimensions

torch_np/tests/numpy_tests/lib/test_shape_base_.py
Lines changed: 1 addition & 1 deletion

@@ -597,7 +597,7 @@ def test_basic(self):
         assert type(res) is np.ndarray

         aa = np.ones((3, 1, 4, 1, 1))
-        assert aa.squeeze().get()._base is aa.get()
+        assert aa.squeeze().tensor._base is aa.tensor

     def test_squeeze_axis(self):
         A = [[[1, 1, 1], [2, 2, 2], [3, 3, 3]]]
