Skip to content

Commit f2513bf

Browse files
committed
API: remove ndarray.base
Base is tracked via `self._tensor._base`. Use `a.get()._base is b.get()` instead of numpy's `a.base is b`.
1 parent b3d5f0a commit f2513bf

File tree

5 files changed

+31
-52
lines changed

5 files changed

+31
-52
lines changed

torch_np/_dtypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __new__(self, value):
4949
# and here we follow the second approach and create a new object
5050
# *for all inputs*.
5151
#
52-
return _ndarray.ndarray._from_tensor_and_base(tensor, None)
52+
return _ndarray.ndarray._from_tensor(tensor)
5353

5454

5555
##### these are abstract types

torch_np/_helpers.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,11 @@ def result_or_out(result_tensor, out_array=None, promote_scalar=False):
8686

8787
def array_from(tensor, base=None):
8888
from ._ndarray import ndarray
89-
90-
base = base if isinstance(base, ndarray) else None
91-
return ndarray._from_tensor_and_base(tensor, base) # XXX: nuke .base
89+
return ndarray._from_tensor(tensor)
9290

9391

9492
def tuple_arrays_from(result):
9593
from ._ndarray import asarray
96-
9794
return tuple(asarray(x) for x in result)
9895

9996

torch_np/_ndarray.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,11 @@ def __getitem__(self, key):
6464
class ndarray:
6565
def __init__(self):
6666
self._tensor = torch.Tensor()
67-
self._base = None
6867

6968
@classmethod
70-
def _from_tensor_and_base(cls, tensor, base):
69+
def _from_tensor(cls, tensor):
7170
self = cls()
7271
self._tensor = tensor
73-
self._base = base
7472
return self
7573

7674
def get(self):
@@ -101,10 +99,6 @@ def strides(self):
10199
def itemsize(self):
102100
return self._tensor.element_size()
103101

104-
@property
105-
def base(self):
106-
return self._base
107-
108102
@property
109103
def flags(self):
110104
# Note contiguous in torch is assumed C-style
@@ -158,7 +152,7 @@ def copy(self, order="C"):
158152
if order != "C":
159153
raise NotImplementedError
160154
tensor = self._tensor.clone()
161-
return ndarray._from_tensor_and_base(tensor, None)
155+
return ndarray._from_tensor(tensor)
162156

163157
def tolist(self):
164158
return self._tensor.tolist()
@@ -398,7 +392,7 @@ def _upcast_int_indices(index):
398392
def __getitem__(self, index):
399393
index = _helpers.ndarrays_to_tensors(index)
400394
index = ndarray._upcast_int_indices(index)
401-
return ndarray._from_tensor_and_base(self._tensor.__getitem__(index), self)
395+
return ndarray._from_tensor(self._tensor.__getitem__(index))
402396

403397
def __setitem__(self, index, value):
404398
index = _helpers.ndarrays_to_tensors(index)
@@ -432,19 +426,16 @@ def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=N
432426
obj = a1
433427

434428
# is obj an ndarray already?
435-
base = None
436429
if isinstance(obj, ndarray):
437-
obj = obj._tensor
438-
base = obj
430+
obj = obj.get()
439431

440432
# is a specific dtype requrested?
441433
torch_dtype = None
442434
if dtype is not None:
443435
torch_dtype = _dtypes.dtype(dtype).torch_dtype
444-
base = None
445436

446437
tensor = _util._coerce_to_tensor(obj, torch_dtype, copy, ndmin)
447-
return ndarray._from_tensor_and_base(tensor, base)
438+
return ndarray._from_tensor(tensor)
448439

449440

450441
def asarray(a, dtype=None, order=None, *, like=None):
@@ -453,10 +444,6 @@ def asarray(a, dtype=None, order=None, *, like=None):
453444
return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0)
454445

455446

456-
def maybe_set_base(tensor, base):
457-
return ndarray._from_tensor_and_base(tensor, base)
458-
459-
460447
###### dtype routines
461448

462449

torch_np/_wrapper.py

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88

99
import torch
1010

11-
from . import _decorators, _dtypes, _funcs, _helpers
11+
from . import _funcs, _helpers
1212
from ._detail import _dtypes_impl, _flips, _reductions, _util
1313
from ._detail import implementations as _impl
14-
from ._ndarray import array, asarray, maybe_set_base, ndarray
14+
from ._ndarray import asarray
1515
from ._normalizations import (
1616
ArrayLike,
1717
DTypeLike,
@@ -176,39 +176,34 @@ def stack(
176176
return _helpers.result_or_out(result, out)
177177

178178

179-
def array_split(ary, indices_or_sections, axis=0):
180-
tensor = asarray(ary).get()
181-
base = ary if isinstance(ary, ndarray) else None
182-
result = _impl.split_helper(tensor, indices_or_sections, axis)
183-
return tuple(maybe_set_base(x, base) for x in result)
179+
@normalizer
180+
def array_split(ary: ArrayLike, indices_or_sections, axis=0):
181+
result = _impl.split_helper(ary, indices_or_sections, axis)
182+
return _helpers.tuple_arrays_from(result)
184183

185184

186-
def split(ary, indices_or_sections, axis=0):
187-
tensor = asarray(ary).get()
188-
base = ary if isinstance(ary, ndarray) else None
189-
result = _impl.split_helper(tensor, indices_or_sections, axis, strict=True)
190-
return tuple(maybe_set_base(x, base) for x in result)
185+
@normalizer
186+
def split(ary: ArrayLike, indices_or_sections, axis=0):
187+
result = _impl.split_helper(ary, indices_or_sections, axis, strict=True)
188+
return _helpers.tuple_arrays_from(result)
191189

192190

193-
def hsplit(ary, indices_or_sections):
194-
tensor = asarray(ary).get()
195-
base = ary if isinstance(ary, ndarray) else None
196-
result = _impl.hsplit(tensor, indices_or_sections)
197-
return tuple(maybe_set_base(x, base) for x in result)
191+
@normalizer
192+
def hsplit(ary: ArrayLike, indices_or_sections):
193+
result = _impl.hsplit(ary, indices_or_sections)
194+
return _helpers.tuple_arrays_from(result)
198195

199196

200-
def vsplit(ary, indices_or_sections):
201-
tensor = asarray(ary).get()
202-
base = ary if isinstance(ary, ndarray) else None
203-
result = _impl.vsplit(tensor, indices_or_sections)
204-
return tuple(maybe_set_base(x, base) for x in result)
197+
@normalizer
198+
def vsplit(ary: ArrayLike, indices_or_sections):
199+
result = _impl.vsplit(ary, indices_or_sections)
200+
return _helpers.tuple_arrays_from(result)
205201

206202

207-
def dsplit(ary, indices_or_sections):
208-
tensor = asarray(ary).get()
209-
base = ary if isinstance(ary, ndarray) else None
210-
result = _impl.dsplit(tensor, indices_or_sections)
211-
return tuple(maybe_set_base(x, base) for x in result)
203+
@normalizer
204+
def dsplit(ary: ArrayLike, indices_or_sections):
205+
result = _impl.dsplit(ary, indices_or_sections)
206+
return _helpers.tuple_arrays_from(result)
212207

213208

214209
@normalizer

torch_np/tests/numpy_tests/core/test_indexing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_empty_tuple_index(self):
115115
# Empty tuple index creates a view
116116
a = np.array([1, 2, 3])
117117
assert_equal(a[()], a)
118-
assert_(a[()].base is a)
118+
assert_(a[()].get()._base is a.get())
119119
a = np.array(0)
120120
pytest.skip(
121121
"torch doesn't have scalar types with distinct instancing behaviours"
@@ -164,7 +164,7 @@ def test_ellipsis_index(self):
164164
assert_(a[...] is not a)
165165
assert_equal(a[...], a)
166166
# `a[...]` was `a` in numpy <1.9.
167-
assert_(a[...].base is a)
167+
assert_(a[...].get()._base is a.get())
168168

169169
# Slicing with ellipsis can skip an
170170
# arbitrary number of dimensions

0 commit comments

Comments
 (0)