Skip to content

Commit f664001

Browse files
authored
Merge pull request #80 from Quansight-Labs/nuke_base
API: remove ndarray.base
2 parents 023d453 + 857f601 commit f664001

File tree

8 files changed

+82
-107
lines changed

8 files changed

+82
-107
lines changed

torch_np/_decorators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def out_shape_dtype(func):
1717
@functools.wraps(func)
1818
def wrapped(*args, out=None, **kwds):
1919
if out is not None:
20-
kwds.update({"out_shape_dtype": (out.get().dtype, out.get().shape)})
20+
kwds.update({"out_shape_dtype": (out.tensor.dtype, out.tensor.shape)})
2121
result_tensor = func(*args, **kwds)
2222
return _helpers.result_or_out(result_tensor, out)
2323

torch_np/_dtypes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __new__(self, value):
3131
value = {"inf": torch.inf, "nan": torch.nan}[value]
3232

3333
if isinstance(value, _ndarray.ndarray):
34-
tensor = value.get()
34+
tensor = value.tensor
3535
else:
3636
try:
3737
tensor = torch.as_tensor(value, dtype=self.torch_dtype)
@@ -49,7 +49,7 @@ def __new__(self, value):
4949
# and here we follow the second approach and create a new object
5050
# *for all inputs*.
5151
#
52-
return _ndarray.ndarray._from_tensor_and_base(tensor, None)
52+
return _ndarray.ndarray(tensor)
5353

5454

5555
##### these are abstract types
@@ -317,7 +317,7 @@ def __repr__(self):
317317
@property
318318
def itemsize(self):
319319
elem = self.type(1)
320-
return elem.get().element_size()
320+
return elem.tensor.element_size()
321321

322322
def __getstate__(self):
323323
return self._scalar_type

torch_np/_helpers.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def ufunc_preprocess(
4949

5050
out_shape_dtype = None
5151
if out is not None:
52-
out_shape_dtype = (out.get().dtype, out.get().shape)
52+
out_shape_dtype = (out.tensor.dtype, out.tensor.shape)
5353

5454
tensors = _util.cast_and_broadcast(tensors, out_shape_dtype, casting)
5555

@@ -77,7 +77,7 @@ def result_or_out(result_tensor, out_array=None, promote_scalar=False):
7777
f"Bad size of the out array: out.shape = {out_array.shape}"
7878
f" while result.shape = {result_tensor.shape}."
7979
)
80-
out_tensor = out_array.get()
80+
out_tensor = out_array.tensor
8181
out_tensor.copy_(result_tensor)
8282
return out_array
8383
else:
@@ -87,8 +87,7 @@ def result_or_out(result_tensor, out_array=None, promote_scalar=False):
8787
def array_from(tensor, base=None):
8888
from ._ndarray import ndarray
8989

90-
base = base if isinstance(base, ndarray) else None
91-
return ndarray._from_tensor_and_base(tensor, base) # XXX: nuke .base
90+
return ndarray(tensor)
9291

9392

9493
def tuple_arrays_from(result):
@@ -109,7 +108,7 @@ def ndarrays_to_tensors(*inputs):
109108
elif len(inputs) == 1:
110109
input_ = inputs[0]
111110
if isinstance(input_, ndarray):
112-
return input_.get()
111+
return input_.tensor
113112
elif isinstance(input_, tuple):
114113
result = []
115114
for sub_input in input_:
@@ -127,4 +126,4 @@ def to_tensors(*inputs):
127126
"""Convert all array_likes from `inputs` to tensors."""
128127
from ._ndarray import asarray, ndarray
129128

130-
return tuple(asarray(value).get() for value in inputs)
129+
return tuple(asarray(value).tensor for value in inputs)

torch_np/_ndarray.py

Lines changed: 37 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -62,48 +62,36 @@ def __getitem__(self, key):
6262

6363

6464
class ndarray:
65-
def __init__(self):
66-
self._tensor = torch.Tensor()
67-
self._base = None
68-
69-
@classmethod
70-
def _from_tensor_and_base(cls, tensor, base):
71-
self = cls()
72-
self._tensor = tensor
73-
self._base = base
74-
return self
75-
76-
def get(self):
77-
return self._tensor
65+
def __init__(self, t=None):
66+
if t is None:
67+
self.tensor = torch.Tensor()
68+
else:
69+
self.tensor = torch.as_tensor(t)
7870

7971
@property
8072
def shape(self):
81-
return tuple(self._tensor.shape)
73+
return tuple(self.tensor.shape)
8274

8375
@property
8476
def size(self):
85-
return self._tensor.numel()
77+
return self.tensor.numel()
8678

8779
@property
8880
def ndim(self):
89-
return self._tensor.ndim
81+
return self.tensor.ndim
9082

9183
@property
9284
def dtype(self):
93-
return _dtypes.dtype(self._tensor.dtype)
85+
return _dtypes.dtype(self.tensor.dtype)
9486

9587
@property
9688
def strides(self):
97-
elsize = self._tensor.element_size()
98-
return tuple(stride * elsize for stride in self._tensor.stride())
89+
elsize = self.tensor.element_size()
90+
return tuple(stride * elsize for stride in self.tensor.stride())
9991

10092
@property
10193
def itemsize(self):
102-
return self._tensor.element_size()
103-
104-
@property
105-
def base(self):
106-
return self._base
94+
return self.tensor.element_size()
10795

10896
@property
10997
def flags(self):
@@ -112,15 +100,15 @@ def flags(self):
112100
# check if F contiguous
113101
from itertools import accumulate
114102

115-
f_strides = tuple(accumulate(list(self._tensor.shape), func=lambda x, y: x * y))
103+
f_strides = tuple(accumulate(list(self.tensor.shape), func=lambda x, y: x * y))
116104
f_strides = (1,) + f_strides[:-1]
117-
is_f_contiguous = f_strides == self._tensor.stride()
105+
is_f_contiguous = f_strides == self.tensor.stride()
118106

119107
return Flags(
120108
{
121-
"C_CONTIGUOUS": self._tensor.is_contiguous(),
109+
"C_CONTIGUOUS": self.tensor.is_contiguous(),
122110
"F_CONTIGUOUS": is_f_contiguous,
123-
"OWNDATA": self._tensor._base is None,
111+
"OWNDATA": self.tensor._base is None,
124112
"WRITEABLE": True, # pytorch does not have readonly tensors
125113
}
126114
)
@@ -135,38 +123,38 @@ def real(self):
135123

136124
@real.setter
137125
def real(self, value):
138-
self._tensor.real = asarray(value).get()
126+
self.tensor.real = asarray(value).tensor
139127

140128
@property
141129
def imag(self):
142130
return _funcs.imag(self)
143131

144132
@imag.setter
145133
def imag(self, value):
146-
self._tensor.imag = asarray(value).get()
134+
self.tensor.imag = asarray(value).tensor
147135

148136
round = _funcs.round
149137

150138
# ctors
151139
def astype(self, dtype):
152140
newt = ndarray()
153141
torch_dtype = _dtypes.dtype(dtype).torch_dtype
154-
newt._tensor = self._tensor.to(torch_dtype)
142+
newt.tensor = self.tensor.to(torch_dtype)
155143
return newt
156144

157145
def copy(self, order="C"):
158146
if order != "C":
159147
raise NotImplementedError
160-
tensor = self._tensor.clone()
161-
return ndarray._from_tensor_and_base(tensor, None)
148+
tensor = self.tensor.clone()
149+
return ndarray(tensor)
162150

163151
def tolist(self):
164-
return self._tensor.tolist()
152+
return self.tensor.tolist()
165153

166154
### niceties ###
167155
def __str__(self):
168156
return (
169-
str(self._tensor)
157+
str(self.tensor)
170158
.replace("tensor", "array_w")
171159
.replace("dtype=torch.", "dtype=")
172160
)
@@ -197,7 +185,7 @@ def __ne__(self, other):
197185

198186
def __bool__(self):
199187
try:
200-
return bool(self._tensor)
188+
return bool(self.tensor)
201189
except RuntimeError:
202190
raise ValueError(
203191
"The truth value of an array with more than one "
@@ -206,35 +194,35 @@ def __bool__(self):
206194

207195
def __index__(self):
208196
try:
209-
return operator.index(self._tensor.item())
197+
return operator.index(self.tensor.item())
210198
except Exception:
211199
mesg = "only integer scalar arrays can be converted to a scalar index"
212200
raise TypeError(mesg)
213201

214202
def __float__(self):
215-
return float(self._tensor)
203+
return float(self.tensor)
216204

217205
def __complex__(self):
218206
try:
219-
return complex(self._tensor)
207+
return complex(self.tensor)
220208
except ValueError as e:
221209
raise TypeError(*e.args)
222210

223211
def __int__(self):
224-
return int(self._tensor)
212+
return int(self.tensor)
225213

226214
# XXX : are single-element ndarrays scalars?
227215
# in numpy, only array scalars have the `is_integer` method
228216
def is_integer(self):
229217
try:
230-
result = int(self._tensor) == self._tensor
218+
result = int(self.tensor) == self.tensor
231219
except Exception:
232220
result = False
233221
return result
234222

235223
### sequence ###
236224
def __len__(self):
237-
return self._tensor.shape[0]
225+
return self.tensor.shape[0]
238226

239227
### arithmetic ###
240228

@@ -360,8 +348,8 @@ def reshape(self, *shape, order="C"):
360348

361349
def sort(self, axis=-1, kind=None, order=None):
362350
# ndarray.sort works in-place
363-
result = _impl.sort(self._tensor, axis, kind, order)
364-
self._tensor = result
351+
result = _impl.sort(self.tensor, axis, kind, order)
352+
self.tensor = result
365353

366354
argsort = _funcs.argsort
367355
searchsorted = _funcs.searchsorted
@@ -398,13 +386,13 @@ def _upcast_int_indices(index):
398386
def __getitem__(self, index):
399387
index = _helpers.ndarrays_to_tensors(index)
400388
index = ndarray._upcast_int_indices(index)
401-
return ndarray._from_tensor_and_base(self._tensor.__getitem__(index), self)
389+
return ndarray(self.tensor.__getitem__(index))
402390

403391
def __setitem__(self, index, value):
404392
index = _helpers.ndarrays_to_tensors(index)
405393
index = ndarray._upcast_int_indices(index)
406394
value = _helpers.ndarrays_to_tensors(value)
407-
return self._tensor.__setitem__(index, value)
395+
return self.tensor.__setitem__(index, value)
408396

409397

410398
# This is the ideally the only place which talks to ndarray directly.
@@ -426,25 +414,22 @@ def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=N
426414
a1 = []
427415
for elem in obj:
428416
if isinstance(elem, ndarray):
429-
a1.append(elem.get().tolist())
417+
a1.append(elem.tensor.tolist())
430418
else:
431419
a1.append(elem)
432420
obj = a1
433421

434422
# is obj an ndarray already?
435-
base = None
436423
if isinstance(obj, ndarray):
437-
obj = obj._tensor
438-
base = obj
424+
obj = obj.tensor
439425

440426
# is a specific dtype requrested?
441427
torch_dtype = None
442428
if dtype is not None:
443429
torch_dtype = _dtypes.dtype(dtype).torch_dtype
444-
base = None
445430

446431
tensor = _util._coerce_to_tensor(obj, torch_dtype, copy, ndmin)
447-
return ndarray._from_tensor_and_base(tensor, base)
432+
return ndarray(tensor)
448433

449434

450435
def asarray(a, dtype=None, order=None, *, like=None):
@@ -453,10 +438,6 @@ def asarray(a, dtype=None, order=None, *, like=None):
453438
return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0)
454439

455440

456-
def maybe_set_base(tensor, base):
457-
return ndarray._from_tensor_and_base(tensor, base)
458-
459-
460441
###### dtype routines
461442

462443

torch_np/_wrapper.py

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88

99
import torch
1010

11-
from . import _decorators, _dtypes, _funcs, _helpers
11+
from . import _funcs, _helpers
1212
from ._detail import _dtypes_impl, _flips, _reductions, _util
1313
from ._detail import implementations as _impl
14-
from ._ndarray import array, asarray, maybe_set_base, ndarray
14+
from ._ndarray import asarray
1515
from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer
1616

1717
# Things to decide on (punt for now)
@@ -169,39 +169,34 @@ def stack(
169169
return _helpers.result_or_out(result, out)
170170

171171

172-
def array_split(ary, indices_or_sections, axis=0):
173-
tensor = asarray(ary).get()
174-
base = ary if isinstance(ary, ndarray) else None
175-
result = _impl.split_helper(tensor, indices_or_sections, axis)
176-
return tuple(maybe_set_base(x, base) for x in result)
172+
@normalizer
173+
def array_split(ary: ArrayLike, indices_or_sections, axis=0):
174+
result = _impl.split_helper(ary, indices_or_sections, axis)
175+
return _helpers.tuple_arrays_from(result)
177176

178177

179-
def split(ary, indices_or_sections, axis=0):
180-
tensor = asarray(ary).get()
181-
base = ary if isinstance(ary, ndarray) else None
182-
result = _impl.split_helper(tensor, indices_or_sections, axis, strict=True)
183-
return tuple(maybe_set_base(x, base) for x in result)
178+
@normalizer
179+
def split(ary: ArrayLike, indices_or_sections, axis=0):
180+
result = _impl.split_helper(ary, indices_or_sections, axis, strict=True)
181+
return _helpers.tuple_arrays_from(result)
184182

185183

186-
def hsplit(ary, indices_or_sections):
187-
tensor = asarray(ary).get()
188-
base = ary if isinstance(ary, ndarray) else None
189-
result = _impl.hsplit(tensor, indices_or_sections)
190-
return tuple(maybe_set_base(x, base) for x in result)
184+
@normalizer
185+
def hsplit(ary: ArrayLike, indices_or_sections):
186+
result = _impl.hsplit(ary, indices_or_sections)
187+
return _helpers.tuple_arrays_from(result)
191188

192189

193-
def vsplit(ary, indices_or_sections):
194-
tensor = asarray(ary).get()
195-
base = ary if isinstance(ary, ndarray) else None
196-
result = _impl.vsplit(tensor, indices_or_sections)
197-
return tuple(maybe_set_base(x, base) for x in result)
190+
@normalizer
191+
def vsplit(ary: ArrayLike, indices_or_sections):
192+
result = _impl.vsplit(ary, indices_or_sections)
193+
return _helpers.tuple_arrays_from(result)
198194

199195

200-
def dsplit(ary, indices_or_sections):
201-
tensor = asarray(ary).get()
202-
base = ary if isinstance(ary, ndarray) else None
203-
result = _impl.dsplit(tensor, indices_or_sections)
204-
return tuple(maybe_set_base(x, base) for x in result)
196+
@normalizer
197+
def dsplit(ary: ArrayLike, indices_or_sections):
198+
result = _impl.dsplit(ary, indices_or_sections)
199+
return _helpers.tuple_arrays_from(result)
205200

206201

207202
@normalizer

0 commit comments

Comments
 (0)