Skip to content

Commit 7af69c4

Browse files
committed
API: remove ndarray.base
Base is tracked via `self._tensor._base`. Use `a.get()._base is b.get()` instead of numpy's `a.base is b`.
1 parent 023d453 commit 7af69c4

File tree

5 files changed

+38
-51
lines changed

5 files changed

+38
-51
lines changed

torch_np/_dtypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __new__(self, value):
4949
# and here we follow the second approach and create a new object
5050
# *for all inputs*.
5151
#
52-
return _ndarray.ndarray._from_tensor_and_base(tensor, None)
52+
return _ndarray.ndarray._from_tensor(tensor)
5353

5454

5555
##### these are abstract types

torch_np/_helpers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@ def result_or_out(result_tensor, out_array=None, promote_scalar=False):
8787
def array_from(tensor, base=None):
8888
from ._ndarray import ndarray
8989

90-
base = base if isinstance(base, ndarray) else None
91-
return ndarray._from_tensor_and_base(tensor, base) # XXX: nuke .base
90+
return ndarray._from_tensor(tensor)
9291

9392

9493
def tuple_arrays_from(result):

torch_np/_ndarray.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,11 @@ def __getitem__(self, key):
6464
class ndarray:
6565
def __init__(self):
6666
self._tensor = torch.Tensor()
67-
self._base = None
6867

6968
@classmethod
70-
def _from_tensor_and_base(cls, tensor, base):
69+
def _from_tensor(cls, tensor):
7170
self = cls()
7271
self._tensor = tensor
73-
self._base = base
7472
return self
7573

7674
def get(self):
@@ -101,10 +99,6 @@ def strides(self):
10199
def itemsize(self):
102100
return self._tensor.element_size()
103101

104-
@property
105-
def base(self):
106-
return self._base
107-
108102
@property
109103
def flags(self):
110104
# Note contiguous in torch is assumed C-style
@@ -158,7 +152,7 @@ def copy(self, order="C"):
158152
if order != "C":
159153
raise NotImplementedError
160154
tensor = self._tensor.clone()
161-
return ndarray._from_tensor_and_base(tensor, None)
155+
return ndarray._from_tensor(tensor)
162156

163157
def tolist(self):
164158
return self._tensor.tolist()
@@ -398,7 +392,7 @@ def _upcast_int_indices(index):
398392
def __getitem__(self, index):
399393
index = _helpers.ndarrays_to_tensors(index)
400394
index = ndarray._upcast_int_indices(index)
401-
return ndarray._from_tensor_and_base(self._tensor.__getitem__(index), self)
395+
return ndarray._from_tensor(self._tensor.__getitem__(index))
402396

403397
def __setitem__(self, index, value):
404398
index = _helpers.ndarrays_to_tensors(index)
@@ -432,19 +426,16 @@ def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=N
432426
obj = a1
433427

434428
# is obj an ndarray already?
435-
base = None
436429
if isinstance(obj, ndarray):
437-
obj = obj._tensor
438-
base = obj
430+
obj = obj.get()
439431

440432
# is a specific dtype requrested?
441433
torch_dtype = None
442434
if dtype is not None:
443435
torch_dtype = _dtypes.dtype(dtype).torch_dtype
444-
base = None
445436

446437
tensor = _util._coerce_to_tensor(obj, torch_dtype, copy, ndmin)
447-
return ndarray._from_tensor_and_base(tensor, base)
438+
return ndarray._from_tensor(tensor)
448439

449440

450441
def asarray(a, dtype=None, order=None, *, like=None):
@@ -453,10 +444,6 @@ def asarray(a, dtype=None, order=None, *, like=None):
453444
return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0)
454445

455446

456-
def maybe_set_base(tensor, base):
457-
return ndarray._from_tensor_and_base(tensor, base)
458-
459-
460447
###### dtype routines
461448

462449

torch_np/_wrapper.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,17 @@
88

99
import torch
1010

11-
from . import _decorators, _dtypes, _funcs, _helpers
11+
from . import _funcs, _helpers
1212
from ._detail import _dtypes_impl, _flips, _reductions, _util
1313
from ._detail import implementations as _impl
14-
from ._ndarray import array, asarray, maybe_set_base, ndarray
15-
from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer
14+
from ._ndarray import asarray
15+
from ._normalizations import (
16+
ArrayLike,
17+
DTypeLike,
18+
NDArray,
19+
SubokLike,
20+
normalizer,
21+
)
1622

1723
# Things to decide on (punt for now)
1824
#
@@ -169,39 +175,34 @@ def stack(
169175
return _helpers.result_or_out(result, out)
170176

171177

172-
def array_split(ary, indices_or_sections, axis=0):
173-
tensor = asarray(ary).get()
174-
base = ary if isinstance(ary, ndarray) else None
175-
result = _impl.split_helper(tensor, indices_or_sections, axis)
176-
return tuple(maybe_set_base(x, base) for x in result)
178+
@normalizer
179+
def array_split(ary: ArrayLike, indices_or_sections, axis=0):
180+
result = _impl.split_helper(ary, indices_or_sections, axis)
181+
return _helpers.tuple_arrays_from(result)
177182

178183

179-
def split(ary, indices_or_sections, axis=0):
180-
tensor = asarray(ary).get()
181-
base = ary if isinstance(ary, ndarray) else None
182-
result = _impl.split_helper(tensor, indices_or_sections, axis, strict=True)
183-
return tuple(maybe_set_base(x, base) for x in result)
184+
@normalizer
185+
def split(ary: ArrayLike, indices_or_sections, axis=0):
186+
result = _impl.split_helper(ary, indices_or_sections, axis, strict=True)
187+
return _helpers.tuple_arrays_from(result)
184188

185189

186-
def hsplit(ary, indices_or_sections):
187-
tensor = asarray(ary).get()
188-
base = ary if isinstance(ary, ndarray) else None
189-
result = _impl.hsplit(tensor, indices_or_sections)
190-
return tuple(maybe_set_base(x, base) for x in result)
190+
@normalizer
191+
def hsplit(ary: ArrayLike, indices_or_sections):
192+
result = _impl.hsplit(ary, indices_or_sections)
193+
return _helpers.tuple_arrays_from(result)
191194

192195

193-
def vsplit(ary, indices_or_sections):
194-
tensor = asarray(ary).get()
195-
base = ary if isinstance(ary, ndarray) else None
196-
result = _impl.vsplit(tensor, indices_or_sections)
197-
return tuple(maybe_set_base(x, base) for x in result)
196+
@normalizer
197+
def vsplit(ary: ArrayLike, indices_or_sections):
198+
result = _impl.vsplit(ary, indices_or_sections)
199+
return _helpers.tuple_arrays_from(result)
198200

199201

200-
def dsplit(ary, indices_or_sections):
201-
tensor = asarray(ary).get()
202-
base = ary if isinstance(ary, ndarray) else None
203-
result = _impl.dsplit(tensor, indices_or_sections)
204-
return tuple(maybe_set_base(x, base) for x in result)
202+
@normalizer
203+
def dsplit(ary: ArrayLike, indices_or_sections):
204+
result = _impl.dsplit(ary, indices_or_sections)
205+
return _helpers.tuple_arrays_from(result)
205206

206207

207208
@normalizer

torch_np/tests/numpy_tests/core/test_indexing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_empty_tuple_index(self):
115115
# Empty tuple index creates a view
116116
a = np.array([1, 2, 3])
117117
assert_equal(a[()], a)
118-
assert_(a[()].base is a)
118+
assert_(a[()].get()._base is a.get())
119119
a = np.array(0)
120120
pytest.skip(
121121
"torch doesn't have scalar types with distinct instancing behaviours"
@@ -164,7 +164,7 @@ def test_ellipsis_index(self):
164164
assert_(a[...] is not a)
165165
assert_equal(a[...], a)
166166
# `a[...]` was `a` in numpy <1.9.
167-
assert_(a[...].base is a)
167+
assert_(a[...].get()._base is a.get())
168168

169169
# Slicing with ellipsis can skip an
170170
# arbitrary number of dimensions

0 commit comments

Comments
 (0)